From a6c04ca09462547c2bf2ddca01687510eaf7d6a4 Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Wed, 20 Jun 2018 14:50:54 -0700 Subject: [PATCH 1/9] thrift_proxy: simple thrift router Provides a very basic thrift router that can route to clusters based on method name only. A Thrift DecoderFilter interface is introduced, but the only available filter is the Router. The Network filter and router are capable of translating transports and protocols but presently cannot be configured to do so. Relates to #2247. *Risk Level*: low *Testing*: unit and integration testing *Docs Changes*: protobuf documentation updated *Release Notes*: introduced a basic thrift_proxy routing extension Signed-off-by: Stephan Zuercher --- .../network/thrift_proxy/v2alpha1/BUILD | 5 +- .../network/thrift_proxy/v2alpha1/route.proto | 42 + .../thrift_proxy/v2alpha1/router/BUILD | 8 + .../thrift_proxy/v2alpha1/router/router.proto | 9 + .../thrift_proxy/v2alpha1/thrift_proxy.proto | 44 ++ source/extensions/extensions_build_config.bzl | 6 + .../filters/network/thrift_proxy/BUILD | 84 +- .../thrift_proxy/app_exception_impl.cc | 34 + .../network/thrift_proxy/app_exception_impl.h | 44 ++ .../thrift_proxy/binary_protocol_impl.cc | 27 +- .../thrift_proxy/binary_protocol_impl.h | 7 +- .../network/thrift_proxy/buffer_helper.h | 47 -- .../thrift_proxy/compact_protocol_impl.cc | 19 +- .../thrift_proxy/compact_protocol_impl.h | 5 +- .../filters/network/thrift_proxy/config.cc | 114 ++- .../filters/network/thrift_proxy/config.h | 41 + .../network/thrift_proxy/conn_manager.cc | 293 +++++++ .../network/thrift_proxy/conn_manager.h | 252 ++++++ .../filters/network/thrift_proxy/decoder.cc | 254 +++--- .../filters/network/thrift_proxy/decoder.h | 92 ++- .../filters/network/thrift_proxy/filter.cc | 310 -------- .../filters/network/thrift_proxy/filter.h | 173 ----- .../network/thrift_proxy/filters/BUILD | 50 ++ .../thrift_proxy/filters/factory_base.h | 45 ++ .../network/thrift_proxy/filters/filter.h | 306 ++++++++ .../thrift_proxy/filters/filter_config.h | 55 ++ .../thrift_proxy/filters/well_known_names.h | 25 + .../thrift_proxy/framed_transport_impl.cc | 29 +- .../thrift_proxy/framed_transport_impl.h | 7 +- .../filters/network/thrift_proxy/protocol.h | 127 +-- .../network/thrift_proxy/protocol_converter.h | 144 ++++ .../network/thrift_proxy/protocol_impl.cc | 14 +- .../network/thrift_proxy/protocol_impl.h | 33 +- .../filters/network/thrift_proxy/router/BUILD | 51 ++ .../network/thrift_proxy/router/config.cc | 34 + .../network/thrift_proxy/router/config.h | 32 + .../network/thrift_proxy/router/router.h | 62 ++ .../thrift_proxy/router/router_impl.cc | 280 +++++++ .../network/thrift_proxy/router/router_impl.h | 157 ++++ .../filters/network/thrift_proxy/stats.h | 52 ++ .../filters/network/thrift_proxy/transport.h | 112 ++- .../network/thrift_proxy/transport_impl.cc | 18 +- .../network/thrift_proxy/transport_impl.h | 28 +- .../thrift_proxy/unframed_transport_impl.cc | 22 + .../thrift_proxy/unframed_transport_impl.h | 14 +- .../filters/network/thrift_proxy/BUILD | 56 +- .../thrift_proxy/binary_protocol_impl_test.cc | 134 +--- .../thrift_proxy/buffer_helper_test.cc | 44 -- .../compact_protocol_impl_test.cc | 157 ++-- .../network/thrift_proxy/config_test.cc | 4 +- .../network/thrift_proxy/conn_manager_test.cc | 723 ++++++++++++++++++ .../network/thrift_proxy/decoder_test.cc | 396 ++++++++-- .../network/thrift_proxy/filter_test.cc | 559 -------------- .../framed_transport_impl_test.cc | 54 +- ...ntegration_test.cc => integration_test.cc} | 41 +- .../filters/network/thrift_proxy/mocks.cc | 66 +- .../filters/network/thrift_proxy/mocks.h | 132 +++- .../thrift_proxy/protocol_impl_test.cc | 35 +- .../network/thrift_proxy/router_test.cc | 596 +++++++++++++++ .../thrift_proxy/transport_impl_test.cc | 83 +- .../unframed_transport_impl_test.cc | 32 +- test/mocks/tcp/BUILD | 1 + test/mocks/tcp/mocks.cc | 8 +- test/mocks/tcp/mocks.h | 16 + 64 files changed, 4876 insertions(+), 1868 deletions(-) create mode 100644 api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto create mode 100644 api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/BUILD create mode 100644 api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.proto create mode 100644 source/extensions/filters/network/thrift_proxy/app_exception_impl.cc create mode 100644 source/extensions/filters/network/thrift_proxy/app_exception_impl.h create mode 100644 source/extensions/filters/network/thrift_proxy/conn_manager.cc create mode 100644 source/extensions/filters/network/thrift_proxy/conn_manager.h delete mode 100644 source/extensions/filters/network/thrift_proxy/filter.cc delete mode 100644 source/extensions/filters/network/thrift_proxy/filter.h create mode 100644 source/extensions/filters/network/thrift_proxy/filters/BUILD create mode 100644 source/extensions/filters/network/thrift_proxy/filters/factory_base.h create mode 100644 source/extensions/filters/network/thrift_proxy/filters/filter.h create mode 100644 source/extensions/filters/network/thrift_proxy/filters/filter_config.h create mode 100644 source/extensions/filters/network/thrift_proxy/filters/well_known_names.h create mode 100644 source/extensions/filters/network/thrift_proxy/protocol_converter.h create mode 100644 source/extensions/filters/network/thrift_proxy/router/BUILD create mode 100644 source/extensions/filters/network/thrift_proxy/router/config.cc create mode 100644 source/extensions/filters/network/thrift_proxy/router/config.h create mode 100644 source/extensions/filters/network/thrift_proxy/router/router.h create mode 100644 source/extensions/filters/network/thrift_proxy/router/router_impl.cc create mode 100644 source/extensions/filters/network/thrift_proxy/router/router_impl.h create mode 100644 source/extensions/filters/network/thrift_proxy/stats.h create mode 100644 source/extensions/filters/network/thrift_proxy/unframed_transport_impl.cc create mode 100644 test/extensions/filters/network/thrift_proxy/conn_manager_test.cc delete mode 100644 test/extensions/filters/network/thrift_proxy/filter_test.cc rename test/extensions/filters/network/thrift_proxy/{filter_integration_test.cc => integration_test.cc} (88%) create mode 100644 test/extensions/filters/network/thrift_proxy/router_test.cc diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD index 60540cd8079a4..da39334babd17 100644 --- a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD @@ -4,5 +4,8 @@ licenses(["notice"]) # Apache 2 api_proto_library_internal( name = "thrift_proxy", - srcs = ["thrift_proxy.proto"], + srcs = [ + "route.proto", + "thrift_proxy.proto", + ], ) diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto new file mode 100644 index 0000000000000..1d78b99647da9 --- /dev/null +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto @@ -0,0 +1,42 @@ +syntax = "proto3"; + +package envoy.extensions.filters.network.thrift_proxy.v2alpha1; +option go_package = "v2"; + +import "validate/validate.proto"; +import "gogoproto/gogo.proto"; + +// [#protodoc-title: Thrift route configuration] + +// [#comment:next free field: 3] +message RouteConfiguration { + // The name of the route configuration. Reserved for future use in asynchronous route discovery. + string name = 1; + + // The list of routes that will be matched, in order, against incoming requests. The first route + // that matches will be used. + repeated Route routes = 2 [(gogoproto.nullable) = false]; +} + +// [#comment:next free field: 3] +message Route { + // Route matching prarameters. + RouteMatch match = 1 [(validate.rules).message.required = true, (gogoproto.nullable) = false]; + + // Route request to some upstream cluster. + RouteAction route = 2 [(validate.rules).message.required = true, (gogoproto.nullable) = false]; + ; +} + +// [#comment:next free field: 2] +message RouteMatch { + // If specified, the route must exactly match the request method name. As a special case, an + // empty string matches any request method name. + string method = 1; +} + +// [#comment:next free field: 2] +message RouteAction { + // Indicates the upstream cluster to which the request should be routed. + string cluster = 1 [(validate.rules).string.min_bytes = 1]; +} diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/BUILD b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/BUILD new file mode 100644 index 0000000000000..ce0ad0e254f03 --- /dev/null +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/BUILD @@ -0,0 +1,8 @@ +load("//bazel:api_build_system.bzl", "api_proto_library_internal") + +licenses(["notice"]) # Apache 2 + +api_proto_library_internal( + name = "router", + srcs = ["router.proto"], +) diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.proto b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.proto new file mode 100644 index 0000000000000..2818c044e7b07 --- /dev/null +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package envoy.extensions.filters.network.thrift_proxy.v2alpha1.router; +option go_package = "router"; + +// [#protodoc-title: Thrift Router] +// Thrift Router configuration. +message Router { +} diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto index e2d6bd02cb261..8897292273202 100644 --- a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto @@ -3,11 +3,55 @@ syntax = "proto3"; package envoy.extensions.filters.network.thrift_proxy.v2alpha1; option go_package = "v2"; +import "envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto"; + import "validate/validate.proto"; +import "gogoproto/gogo.proto"; // [#protodoc-title: Extensions Thrift Proxy] // Thrift Proxy filter configuration. +// [#comment:next free field: 5] message ThriftProxy { + enum TransportType { + option (gogoproto.goproto_enum_prefix) = false; + + // For every new connection, the Thrift proxy will determine which transport to use. + AUTO_TRANSPORT = 0; + + // The Thrift proxy will assume the client is using the Thrift framed transport. + FRAMED = 1; + + // The Thrift proxy will assume the client is using the Thrift unframed transport. + UNFRAMED = 2; + } + + // Supplies the type of transport that the Thrift proxy should use. Defaults to `AUTO_TRANSPORT`. + TransportType transport = 2 [(validate.rules).enum.defined_only = true]; + + enum ProtocolType { + option (gogoproto.goproto_enum_prefix) = false; + + // For every new connection, the Thrift proxy will determine which protocol to use. + // N.B. The older, non-strict binary protocol is not included in automatic protocol + // detection. + AUTO_PROTOCOL = 0; + + // The Thrift proxy will assume the client is using the Thrift binary protocol. + BINARY = 1; + + // The Thrift proxy will assume the client is using the Thrift non-strict binary protocol. + LAX_BINARY = 2; + + // The Thrift proxy will assume the client is using the Thrift compact protocol. + COMPACT = 3; + } + + // Supplies the type of protocol that the Thrift proxy should use. Defaults to `AUTO_PROTOCOL`. + ProtocolType protocol = 3 [(validate.rules).enum.defined_only = true]; + // The human readable prefix to use when emitting statistics. string stat_prefix = 1 [(validate.rules).string.min_bytes = 1]; + + // The route table for the connection manager is static and is specified in this property. + RouteConfiguration route_config = 4; } diff --git a/source/extensions/extensions_build_config.bzl b/source/extensions/extensions_build_config.bzl index 04b1d40a6adb0..42a2e194c8f32 100644 --- a/source/extensions/extensions_build_config.bzl +++ b/source/extensions/extensions_build_config.bzl @@ -79,6 +79,12 @@ EXTENSIONS = { "envoy.stat_sinks.metrics_service": "//source/extensions/stat_sinks/metrics_service:config", "envoy.stat_sinks.statsd": "//source/extensions/stat_sinks/statsd:config", + # + # Thrift filters + # + + "envoy.filters.thrift.router": "//source/extensions/filters/network/thrift_proxy/router:config", + # # Tracers # diff --git a/source/extensions/filters/network/thrift_proxy/BUILD b/source/extensions/filters/network/thrift_proxy/BUILD index 48555347641b0..c69373d53e375 100644 --- a/source/extensions/filters/network/thrift_proxy/BUILD +++ b/source/extensions/filters/network/thrift_proxy/BUILD @@ -8,6 +8,17 @@ load( envoy_package() +envoy_cc_library( + name = "app_exception_lib", + srcs = ["app_exception_impl.cc"], + hdrs = ["app_exception_impl.h"], + deps = [ + ":protocol_interface", + "//include/envoy/buffer:buffer_interface", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + ], +) + envoy_cc_library( name = "buffer_helper_lib", srcs = ["buffer_helper.cc"], @@ -24,41 +35,68 @@ envoy_cc_library( srcs = ["config.cc"], hdrs = ["config.h"], deps = [ - ":filter_lib", + ":conn_manager_lib", + ":decoder_lib", + ":protocol_lib", "//include/envoy/registry", - "//source/common/config:filter_json_lib", + "//source/common/config:utility_lib", "//source/extensions/filters/network:well_known_names", "//source/extensions/filters/network/common:factory_base_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_config_interface", + "//source/extensions/filters/network/thrift_proxy/filters:well_known_names", + "//source/extensions/filters/network/thrift_proxy/router:router_lib", "@envoy_api//envoy/extensions/filters/network/thrift_proxy/v2alpha1:thrift_proxy_cc", ], ) +envoy_cc_library( + name = "conn_manager_lib", + srcs = ["conn_manager.cc"], + hdrs = ["conn_manager.h"], + deps = [ + ":app_exception_lib", + ":decoder_lib", + ":protocol_converter_lib", + ":protocol_lib", + ":stats_lib", + ":transport_lib", + "//include/envoy/event:deferred_deletable", + "//include/envoy/event:dispatcher_interface", + "//include/envoy/network:connection_interface", + "//include/envoy/network:filter_interface", + "//include/envoy/stats:stats_interface", + "//include/envoy/stats:timespan", + "//source/common/buffer:buffer_lib", + "//source/common/common:assert_lib", + "//source/common/common:linked_object", + "//source/common/common:logger_lib", + "//source/common/network:filter_lib", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + ], +) + envoy_cc_library( name = "decoder_lib", srcs = ["decoder.cc"], hdrs = ["decoder.h"], deps = [ ":protocol_lib", + ":stats_lib", ":transport_lib", "//source/common/buffer:buffer_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", ], ) envoy_cc_library( - name = "filter_lib", - srcs = ["filter.cc"], - hdrs = ["filter.h"], + name = "protocol_converter_lib", + hdrs = [ + "protocol_converter.h", + ], deps = [ - ":decoder_lib", - "//include/envoy/network:connection_interface", - "//include/envoy/network:filter_interface", - "//include/envoy/stats:stats_interface", - "//include/envoy/stats:stats_macros", - "//include/envoy/stats:timespan", - "//source/common/buffer:buffer_lib", - "//source/common/common:assert_lib", - "//source/common/common:logger_lib", - "//source/common/network:filter_lib", + ":protocol_interface", + "//include/envoy/buffer:buffer_interface", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", ], ) @@ -70,6 +108,9 @@ envoy_cc_library( external_deps = ["abseil_optional"], deps = [ "//include/envoy/buffer:buffer_interface", + "//include/envoy/registry", + "//source/common/common:assert_lib", + "//source/common/config:utility_lib", "//source/common/singleton:const_singleton", ], ) @@ -94,12 +135,24 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "stats_lib", + hdrs = ["stats.h"], + deps = [ + "//include/envoy/stats:stats_interface", + "//include/envoy/stats:stats_macros", + ], +) + envoy_cc_library( name = "transport_interface", hdrs = ["transport.h"], external_deps = ["abseil_optional"], deps = [ "//include/envoy/buffer:buffer_interface", + "//include/envoy/registry", + "//source/common/common:assert_lib", + "//source/common/config:utility_lib", "//source/common/singleton:const_singleton", ], ) @@ -109,6 +162,7 @@ envoy_cc_library( srcs = [ "framed_transport_impl.cc", "transport_impl.cc", + "unframed_transport_impl.cc", ], hdrs = [ "framed_transport_impl.h", diff --git a/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc b/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc new file mode 100644 index 0000000000000..65455c12b3609 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc @@ -0,0 +1,34 @@ +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +static const std::string TApplicationException = "TApplicationException"; +static const std::string MessageField = "message"; +static const std::string TypeField = "type"; +static const std::string StopField = ""; + +void AppException::encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) { + proto.writeMessageBegin(buffer, method_name_, ThriftProxy::MessageType::Exception, seq_id_); + proto.writeStructBegin(buffer, TApplicationException); + + proto.writeFieldBegin(buffer, MessageField, ThriftProxy::FieldType::String, 1); + proto.writeString(buffer, error_message_); + proto.writeFieldEnd(buffer); + + proto.writeFieldBegin(buffer, TypeField, ThriftProxy::FieldType::I32, 2); + proto.writeInt32(buffer, static_cast(type_)); + proto.writeFieldEnd(buffer); + + proto.writeFieldBegin(buffer, StopField, ThriftProxy::FieldType::Stop, 0); + + proto.writeStructEnd(buffer); + proto.writeMessageEnd(buffer); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/app_exception_impl.h b/source/extensions/filters/network/thrift_proxy/app_exception_impl.h new file mode 100644 index 0000000000000..4a0335704100a --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/app_exception_impl.h @@ -0,0 +1,44 @@ +#pragma once + +#include "extensions/filters/network/thrift_proxy/filters/filter.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * Thrift Application Exception types. + * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md + */ +enum class AppExceptionType { + Unknown = 0, + UnknownMethod = 1, + InvalidMessageType = 2, + WrongMethodName = 3, + BadSequenceId = 4, + MissingResult = 5, + InternalError = 6, + ProtocolError = 7, + InvalidTransform = 8, + InvalidProtocol = 9, + UnsupportedClientType = 10, +}; + +struct AppException : public ThriftFilters::DirectResponse { + AppException(const absl::string_view method_name, int32_t seq_id, AppExceptionType type, + const std::string& error_message) + : method_name_(method_name), seq_id_(seq_id), type_(type), error_message_(error_message) {} + + void encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) override; + + const std::string method_name_; + const int32_t seq_id_; + const AppExceptionType type_; + const std::string error_message_; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc index 3fc1038962e8a..ee5a734eda209 100644 --- a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc @@ -60,26 +60,22 @@ bool BinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& msg_type = type; seq_id = BufferHelper::drainI32(buffer); - onMessageStart(absl::string_view(name), msg_type, seq_id); return true; } bool BinaryProtocolImpl::readMessageEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); - onMessageComplete(); return true; } bool BinaryProtocolImpl::readStructBegin(Buffer::Instance& buffer, std::string& name) { UNREFERENCED_PARAMETER(buffer); name.clear(); // binary protocol does not transmit struct names - onStructBegin(absl::string_view(name)); return true; } bool BinaryProtocolImpl::readStructEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); - onStructEnd(); return true; } @@ -110,7 +106,6 @@ bool BinaryProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& n name.clear(); // binary protocol does not transmit field names field_type = type; - onStructField(absl::string_view(name), field_type, field_id); return true; } @@ -402,7 +397,6 @@ bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::stri seq_id = BufferHelper::peekI32(buffer, 1); buffer.drain(5); - onMessageStart(absl::string_view(name), msg_type, seq_id); return true; } @@ -413,6 +407,27 @@ void LaxBinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const st BufferHelper::writeI32(buffer, seq_id); } +class BinaryProtocolConfigFactory : public ProtocolFactoryBase { +public: + BinaryProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().BINARY) {} +}; + +/** + * Static registration for the binary protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory register_; + +class LaxBinaryProtocolConfigFactory : public ProtocolFactoryBase { +public: + LaxBinaryProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().LAX_BINARY) {} +}; + +/** + * Static registration for the auto protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_lax_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h index fac781ea563b9..e292d6cd036b9 100644 --- a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h +++ b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h @@ -16,12 +16,13 @@ namespace ThriftProxy { * BinaryProtocolImpl implements the Thrift Binary protocol with strict message encoding. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-binary-protocol.md */ -class BinaryProtocolImpl : public ProtocolImplBase { +class BinaryProtocolImpl : public Protocol { public: - BinaryProtocolImpl(ProtocolCallbacks& callbacks) : ProtocolImplBase(callbacks) {} + BinaryProtocolImpl() {} // Protocol const std::string& name() const override { return ProtocolNames::get().BINARY; } + ProtocolType type() const override { return ProtocolType::Binary; } bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id) override; bool readMessageEnd(Buffer::Instance& buffer) override; @@ -81,7 +82,7 @@ class BinaryProtocolImpl : public ProtocolImplBase { */ class LaxBinaryProtocolImpl : public BinaryProtocolImpl { public: - LaxBinaryProtocolImpl(ProtocolCallbacks& callbacks) : BinaryProtocolImpl(callbacks) {} + LaxBinaryProtocolImpl() {} const std::string& name() const override { return ProtocolNames::get().LAX_BINARY; } diff --git a/source/extensions/filters/network/thrift_proxy/buffer_helper.h b/source/extensions/filters/network/thrift_proxy/buffer_helper.h index 1689d0849df57..c4945cd5da6b5 100644 --- a/source/extensions/filters/network/thrift_proxy/buffer_helper.h +++ b/source/extensions/filters/network/thrift_proxy/buffer_helper.h @@ -10,53 +10,6 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -/** - * BufferWrapper provides a partial implementation of Buffer::Instance that is sufficient for - * BufferHelper to read Thrift protocol data without draining the buffer's contents. - */ -class BufferWrapper : public Buffer::Instance { -public: - BufferWrapper(Buffer::Instance& underlying) : underlying_(underlying) {} - - uint64_t position() { return position_; } - - // Buffer::Instance - void copyOut(size_t start, uint64_t size, void* data) const override { - ASSERT(position_ + start + size <= underlying_.length()); - underlying_.copyOut(start + position_, size, data); - } - void drain(uint64_t size) override { - ASSERT(position_ + size <= underlying_.length()); - position_ += size; - } - uint64_t length() const override { - ASSERT(underlying_.length() >= position_); - return underlying_.length() - position_; - } - void* linearize(uint32_t size) override { - ASSERT(position_ + size <= underlying_.length()); - uint8_t* p = static_cast(underlying_.linearize(position_ + size)); - return p + position_; - } - void add(const void*, uint64_t) override { NOT_IMPLEMENTED; } - void addBufferFragment(Buffer::BufferFragment&) override { NOT_IMPLEMENTED; } - void add(const std::string&) override { NOT_IMPLEMENTED; } - void add(const Buffer::Instance&) override { NOT_IMPLEMENTED; } - void commit(Buffer::RawSlice*, uint64_t) override { NOT_IMPLEMENTED; } - uint64_t getRawSlices(Buffer::RawSlice*, uint64_t) const override { NOT_IMPLEMENTED; } - void move(Buffer::Instance&) override { NOT_IMPLEMENTED; } - void move(Buffer::Instance&, uint64_t) override { NOT_IMPLEMENTED; } - int read(int, uint64_t) override { NOT_IMPLEMENTED; } - uint64_t reserve(uint64_t, Buffer::RawSlice*, uint64_t) override { NOT_IMPLEMENTED; } - ssize_t search(const void*, uint64_t, size_t) const override { NOT_IMPLEMENTED; } - int write(int) override { NOT_IMPLEMENTED; } - std::string toString() const override { NOT_IMPLEMENTED; } - -private: - Buffer::Instance& underlying_; - uint64_t position_{0}; -}; - /** * BufferHelper provides buffer operations for reading bytes and numbers in the various encodings * used by Thrift protocols. diff --git a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc index 5ea3a3cafb483..417a80d8b6197 100644 --- a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc @@ -72,13 +72,11 @@ bool CompactProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string msg_type = type; seq_id = id; - onMessageStart(absl::string_view(name), msg_type, seq_id); return true; } bool CompactProtocolImpl::readMessageEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); - onMessageComplete(); return true; } @@ -91,7 +89,6 @@ bool CompactProtocolImpl::readStructBegin(Buffer::Instance& buffer, std::string& last_field_id_stack_.push(last_field_id_); last_field_id_ = 0; - onStructBegin(absl::string_view(name)); return true; } @@ -105,7 +102,6 @@ bool CompactProtocolImpl::readStructEnd(Buffer::Instance& buffer) { last_field_id_ = last_field_id_stack_.top(); last_field_id_stack_.pop(); - onStructEnd(); return true; } @@ -124,7 +120,6 @@ bool CompactProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& field_type = FieldType::Stop; buffer.drain(1); - onStructField(absl::string_view(name), field_type, field_id); return true; } @@ -166,7 +161,6 @@ bool CompactProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& buffer.drain(id_size + 1); - onStructField(absl::string_view(name), field_type, field_id); return true; } @@ -459,7 +453,7 @@ void CompactProtocolImpl::writeFieldBeginInternal( static_cast(compact_field_type)); } else { BufferHelper::writeI8(buffer, static_cast(compact_field_type)); - BufferHelper::writeI16(buffer, field_id); + BufferHelper::writeZigZagI32(buffer, static_cast(field_id)); } last_field_id_ = field_id; @@ -623,6 +617,17 @@ CompactProtocolImpl::CompactFieldType CompactProtocolImpl::convertFieldType(Fiel } } +class CompactProtocolConfigFactory : public ProtocolFactoryBase { +public: + CompactProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().COMPACT) {} +}; + +/** + * Static registration for the binary protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h index 72bad6dee4482..322d03a3a83da 100644 --- a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h +++ b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h @@ -19,12 +19,13 @@ namespace ThriftProxy { * CompactProtocolImpl implements the Thrift Compact protocol. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md */ -class CompactProtocolImpl : public ProtocolImplBase { +class CompactProtocolImpl : public Protocol { public: - CompactProtocolImpl(ProtocolCallbacks& callbacks) : ProtocolImplBase(callbacks) {} + CompactProtocolImpl() {} // Protocol const std::string& name() const override { return ProtocolNames::get().COMPACT; } + ProtocolType type() const override { return ProtocolType::Compact; } bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id) override; bool readMessageEnd(Buffer::Instance& buffer) override; diff --git a/source/extensions/filters/network/thrift_proxy/config.cc b/source/extensions/filters/network/thrift_proxy/config.cc index 49f169eee4815..aa7201f483ec6 100644 --- a/source/extensions/filters/network/thrift_proxy/config.cc +++ b/source/extensions/filters/network/thrift_proxy/config.cc @@ -1,26 +1,81 @@ #include "extensions/filters/network/thrift_proxy/config.h" +#include #include #include "envoy/network/connection.h" #include "envoy/registry/registry.h" -#include "extensions/filters/network/thrift_proxy/filter.h" +#include "common/config/utility.h" + +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/decoder.h" +#include "extensions/filters/network/thrift_proxy/filters/filter_config.h" +#include "extensions/filters/network/thrift_proxy/filters/well_known_names.h" +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/stats.h" +#include "extensions/filters/network/thrift_proxy/transport_impl.h" +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +namespace { + +typedef std::map< + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy_TransportType, + TransportType> + TransportTypeMap; + +static const TransportTypeMap& transportTypeMap() { + CONSTRUCT_ON_FIRST_USE(TransportTypeMap, + { + {envoy::extensions::filters::network::thrift_proxy::v2alpha1:: + ThriftProxy_TransportType_AUTO_TRANSPORT, + TransportType::Auto}, + {envoy::extensions::filters::network::thrift_proxy::v2alpha1:: + ThriftProxy_TransportType_FRAMED, + TransportType::Framed}, + {envoy::extensions::filters::network::thrift_proxy::v2alpha1:: + ThriftProxy_TransportType_UNFRAMED, + TransportType::Unframed}, + }); +} + +typedef std::map< + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy_ProtocolType, + ProtocolType> + ProtocolTypeMap; + +static const ProtocolTypeMap& protocolTypeMap() { + CONSTRUCT_ON_FIRST_USE(ProtocolTypeMap, { + {envoy::extensions::filters::network::thrift_proxy:: + v2alpha1::ThriftProxy_ProtocolType_AUTO_PROTOCOL, + ProtocolType::Auto}, + {envoy::extensions::filters::network::thrift_proxy:: + v2alpha1::ThriftProxy_ProtocolType_BINARY, + ProtocolType::Binary}, + {envoy::extensions::filters::network::thrift_proxy:: + v2alpha1::ThriftProxy_ProtocolType_LAX_BINARY, + ProtocolType::LaxBinary}, + {envoy::extensions::filters::network::thrift_proxy:: + v2alpha1::ThriftProxy_ProtocolType_COMPACT, + ProtocolType::Compact}, + }); +} + +} // namespace Network::FilterFactoryCb ThriftProxyFilterConfigFactory::createFilterFactoryFromProtoTyped( const envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy& proto_config, Server::Configuration::FactoryContext& context) { - ASSERT(!proto_config.stat_prefix().empty()); - - const std::string stat_prefix = fmt::format("thrift.{}.", proto_config.stat_prefix()); + std::shared_ptr filter_config(new ConfigImpl(proto_config, context)); - return [stat_prefix, &context](Network::FilterManager& filter_manager) -> void { - filter_manager.addFilter(std::make_shared(stat_prefix, context.scope())); + return [filter_config](Network::FilterManager& filter_manager) -> void { + filter_manager.addReadFilter(std::make_shared(*filter_config)); }; } @@ -31,6 +86,53 @@ static Registry::RegisterFactory registered_; +ConfigImpl::ConfigImpl( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy& config, + Server::Configuration::FactoryContext& context) + : context_(context), stats_prefix_(fmt::format("thrift.{}.", config.stat_prefix())), + stats_(ThriftFilterStats::generateStats(stats_prefix_, context_.scope())), + transport_(config.transport()), proto_(config.protocol()), + route_matcher_(new Router::RouteMatcher(config.route_config())) { + + // Construct the only Thrift DecoderFilter: the Router + auto& factory = + Envoy::Config::Utility::getAndCheckFactory( + ThriftFilters::ThriftFilterNames::get().ROUTER); + ThriftFilters::FilterFactoryCb callback; + + auto empty_config = factory.createEmptyConfigProto(); + callback = factory.createFilterFactoryFromProto(*empty_config, stats_prefix_, context_); + filter_factories_.push_back(callback); +} + +void ConfigImpl::createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) { + for (const ThriftFilters::FilterFactoryCb& factory : filter_factories_) { + factory(callbacks); + } +} + +DecoderPtr ConfigImpl::createDecoder(DecoderCallbacks& callbacks) { + return std::make_unique(createTransport(), createProtocol(), callbacks); +} + +TransportPtr ConfigImpl::createTransport() { + TransportTypeMap::const_iterator i = transportTypeMap().find(transport_); + if (i == transportTypeMap().end()) { + throw new EnvoyException(fmt::format("unknown transport type: {}", transport_)); + } + + return NamedTransportConfigFactory::getFactory(i->second).createTransport(); +} + +ProtocolPtr ConfigImpl::createProtocol() { + ProtocolTypeMap::const_iterator i = protocolTypeMap().find(proto_); + if (i == protocolTypeMap().end()) { + throw new EnvoyException(fmt::format("unknown protocol type: {}", proto_)); + } + + return NamedProtocolConfigFactory::getFactory(i->second).createProtocol(); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/config.h b/source/extensions/filters/network/thrift_proxy/config.h index d578bd96591ac..265f80f57ab23 100644 --- a/source/extensions/filters/network/thrift_proxy/config.h +++ b/source/extensions/filters/network/thrift_proxy/config.h @@ -1,11 +1,16 @@ #pragma once +#include #include #include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" #include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.validate.h" +#include "envoy/stats/stats.h" #include "extensions/filters/network/common/factory_base.h" +#include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" #include "extensions/filters/network/well_known_names.h" namespace Envoy { @@ -28,6 +33,42 @@ class ThriftProxyFilterConfigFactory Server::Configuration::FactoryContext& context) override; }; +class ConfigImpl : public Config, + public Router::Config, + public ThriftFilters::FilterChainFactory, + Logger::Loggable { +public: + ConfigImpl(const envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy& config, + Server::Configuration::FactoryContext& context); + + // ThriftFilters::FilterChainFactory + void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override; + + // Router::Config + Router::RouteConstSharedPtr route(const std::string& method_name) const override { + return route_matcher_->route(method_name); + } + + // Config: + ThriftFilterStats& stats() override { return stats_; } + ThriftFilters::FilterChainFactory& filterFactory() override { return *this; } + DecoderPtr createDecoder(DecoderCallbacks& callbacks) override; + Router::Config& routerConfig() override { return *this; } + +private: + TransportPtr createTransport(); + ProtocolPtr createProtocol(); + + Server::Configuration::FactoryContext& context_; + const std::string stats_prefix_; + ThriftFilterStats stats_; + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy_TransportType transport_; + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy_ProtocolType proto_; + std::unique_ptr route_matcher_; + + std::list filter_factories_; +}; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc new file mode 100644 index 0000000000000..ff066564ce620 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -0,0 +1,293 @@ +#include "extensions/filters/network/thrift_proxy/conn_manager.h" + +#include "envoy/common/exception.h" +#include "envoy/event/dispatcher.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +ConnectionManager::ConnectionManager(Config& config) + : config_(config), stats_(config_.stats()), decoder_(config_.createDecoder(*this)) {} + +ConnectionManager::~ConnectionManager() {} + +Network::FilterStatus ConnectionManager::onData(Buffer::Instance& data, bool end_stream) { + UNREFERENCED_PARAMETER(end_stream); + + request_buffer_.move(data); + dispatch(); + + return Network::FilterStatus::StopIteration; +} + +void ConnectionManager::dispatch() { + if (stopped_) { + ENVOY_LOG(error, "thrift filter stopped"); + return; + } + + try { + bool underflow = false; + while (!underflow) { + ThriftFilters::FilterStatus status = decoder_->onData(request_buffer_, underflow); + if (status == ThriftFilters::FilterStatus::StopIteration) { + stopped_ = true; + break; + } + } + } catch (const EnvoyException& ex) { + ENVOY_LOG(error, "thrift error: {}", ex.what()); + stats_.request_decoding_error_.inc(); + + // Use the current rpc to send an error downstream, if possible. + rpcs_.front()->onError(ex.what()); + + resetAllRpcs(); + } +} + +void ConnectionManager::continueDecoding() { + stopped_ = false; + dispatch(); +} + +void ConnectionManager::doDeferredRpcDestroy(ConnectionManager::ActiveRpc& rpc) { + read_callbacks_->connection().dispatcher().deferredDelete(rpc.removeFromList(rpcs_)); +} + +void ConnectionManager::resetAllRpcs() { + while (!rpcs_.empty()) { + rpcs_.front()->onReset(); + } + + read_callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite); +} + +void ConnectionManager::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) { + read_callbacks_ = &callbacks; +} + +void ConnectionManager::onEvent(Network::ConnectionEvent event) { + if (rpcs_.empty()) { + return; + } + + if (event == Network::ConnectionEvent::RemoteClose) { + stats_.cx_destroy_local_with_active_rq_.inc(); + } + + if (event == Network::ConnectionEvent::LocalClose) { + stats_.cx_destroy_remote_with_active_rq_.inc(); + } +} + +ThriftFilters::DecoderFilter& ConnectionManager::newDecoderFilter() { + ENVOY_LOG(debug, "new decoder filter"); + + ActiveRpcPtr new_rpc(new ActiveRpc(*this)); + new_rpc->createFilterChain(); + new_rpc->moveIntoList(std::move(new_rpc), rpcs_); + + return **rpcs_.begin(); +} + +bool ConnectionManager::ResponseDecoder::onData(Buffer::Instance& data) { + upstream_buffer_.move(data); + + bool underflow = false; + decoder_->onData(upstream_buffer_, underflow); + ASSERT(complete_ || underflow); + return complete_; +} + +ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::messageBegin(absl::string_view name, + MessageType msg_type, + int32_t seq_id) { + reply_.emplace(std::string(name), msg_type, seq_id); + first_reply_field_ = (msg_type == MessageType::Reply); + return ProtocolConverter::messageBegin(name, msg_type, seq_id); +} + +ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::fieldBegin(absl::string_view name, + FieldType field_type, + int16_t field_id) { + if (first_reply_field_) { + // Reply messages contain a struct where field 0 is the call result and fields 1+ are + // exceptions, if defined. At most one field may be set. Therefore, the very first field we + // encounter in a reply is either field 0 (success) or not (IDL exception returned). + ASSERT(reply_.has_value()); + reply_.value().success_ = field_id == 0 && field_type != FieldType::Stop; + first_reply_field_ = false; + } + + return ProtocolConverter::fieldBegin(name, field_type, field_id); +} + +ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::transportEnd() { + ConnectionManager& cm = parent_.parent_; + + Buffer::OwnedImpl buffer; + + // Use the factory to get the concrete transport from the decoder transport (as opposed to + // potentially pre-detection auto transport). + TransportPtr transport = + NamedTransportConfigFactory::getFactory(parent_.parent_.decoder_->transportType()) + .createTransport(); + transport->encodeFrame(buffer, parent_.response_buffer_); + complete_ = true; + + cm.read_callbacks_->connection().write(buffer, false); + + cm.stats_.response_.inc(); + + ASSERT(reply_.has_value()); + switch (reply_.value().msg_type_) { + case MessageType::Reply: + cm.stats_.response_reply_.inc(); + if (reply_.value().success_.value_or(false)) { + cm.stats_.response_success_.inc(); + } else { + cm.stats_.response_error_.inc(); + } + + break; + + case MessageType::Exception: + cm.stats_.response_exception_.inc(); + break; + + default: + cm.stats_.response_invalid_type_.inc(); + break; + } + + return ThriftFilters::FilterStatus::Continue; +} + +ThriftFilters::FilterStatus ConnectionManager::ActiveRpc::transportEnd() { + ASSERT(call_.has_value()); + + parent_.stats_.request_.inc(); + + switch (call_.value().msg_type_) { + case MessageType::Call: + parent_.stats_.request_call_.inc(); + break; + + case MessageType::Oneway: + parent_.stats_.request_oneway_.inc(); + + // No response forthcoming, we're done. + parent_.doDeferredRpcDestroy(*this); + break; + + default: + parent_.stats_.request_invalid_type_.inc(); + break; + } + + return decoder_filter_->transportEnd(); +} + +void ConnectionManager::ActiveRpc::createFilterChain() { + parent_.config_.filterFactory().createFilterChain(*this); +} + +void ConnectionManager::ActiveRpc::onReset() { + // TODO(zuercer): e.g., parent_.stats_.named_.downstream_rq_rx_reset_.inc(); + parent_.doDeferredRpcDestroy(*this); +} + +void ConnectionManager::ActiveRpc::onError(const std::string& what) { + if (call_.has_value()) { + const Message& msg = call_.value(); + sendLocalReply(std::make_unique(msg.method_name_, msg.seq_id_, + AppExceptionType::ProtocolError, what)); + return; + } + + // Transport or protocol error happened before (or during message begin) parsing. It's not + // possible to provide a valid response, so don't try. +} + +const Network::Connection* ConnectionManager::ActiveRpc::connection() { + return &parent_.read_callbacks_->connection(); +} + +void ConnectionManager::ActiveRpc::continueDecoding() { parent_.continueDecoding(); } + +Router::RouteConstSharedPtr ConnectionManager::ActiveRpc::route() { + if (!cached_route_) { + if (call_.has_value()) { + Router::RouteConstSharedPtr route = + parent_.config_.routerConfig().route(call_.value().method_name_); + cached_route_ = std::move(route); + } else { + cached_route_ = nullptr; + } + } + + return cached_route_.value(); +} + +void ConnectionManager::ActiveRpc::sendLocalReply(ThriftFilters::DirectResponsePtr&& response) { + // Use the factory to get the concrete protocol from the decoder protocol (as opposed to + // potentially pre-detection auto protocol). + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(parent_.decoder_->protocolType()).createProtocol(); + Buffer::OwnedImpl buffer; + + response->encode(*proto, buffer); + + // Same logic as protocol above. + TransportPtr transport = + NamedTransportConfigFactory::getFactory(parent_.decoder_->transportType()).createTransport(); + transport->encodeFrame(response_buffer_, buffer); + + parent_.read_callbacks_->connection().write(response_buffer_, false); + parent_.doDeferredRpcDestroy(*this); +} + +void ConnectionManager::ActiveRpc::startUpstreamResponse(TransportType transport_type, + ProtocolType protocol_type) { + ASSERT(response_decoder_ == nullptr); + + response_decoder_ = std::make_unique(*this, transport_type, protocol_type); +} + +bool ConnectionManager::ActiveRpc::upstreamData(Buffer::Instance& buffer) { + ASSERT(response_decoder_ != nullptr); + + try { + bool complete = response_decoder_->onData(buffer); + if (complete) { + parent_.doDeferredRpcDestroy(*this); + } + return complete; + } catch (const EnvoyException& ex) { + ENVOY_LOG(error, "thrift response error: {}", ex.what()); + parent_.stats_.response_decoding_error_.inc(); + + onError(ex.what()); + decoder_filter_->resetUpstreamConnection(); + return true; + } +} + +void ConnectionManager::ActiveRpc::resetDownstreamConnection() { + parent_.read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + parent_.doDeferredRpcDestroy(*this); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h new file mode 100644 index 0000000000000..82eb20feab176 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -0,0 +1,252 @@ +#pragma once + +#include "envoy/common/pure.h" +#include "envoy/event/deferred_deletable.h" +#include "envoy/network/connection.h" +#include "envoy/network/filter.h" +#include "envoy/stats/stats.h" +#include "envoy/stats/timespan.h" + +#include "common/buffer/buffer_impl.h" +#include "common/common/linked_object.h" +#include "common/common/logger.h" + +#include "extensions/filters/network/thrift_proxy/decoder.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/protocol_converter.h" +#include "extensions/filters/network/thrift_proxy/stats.h" +#include "extensions/filters/network/thrift_proxy/transport.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * Config is a configuration interface for ConnectionManager. + */ +class Config { +public: + virtual ~Config() {} + + virtual ThriftFilters::FilterChainFactory& filterFactory() PURE; + virtual ThriftFilterStats& stats() PURE; + virtual DecoderPtr createDecoder(DecoderCallbacks& callbacks) PURE; + virtual Router::Config& routerConfig() PURE; +}; + +/** + * ConnectionManager is a Network::Filter that will perform Thrift request handling on a connection. + */ +class ConnectionManager : public Network::ReadFilter, + public Network::ConnectionCallbacks, + public DecoderCallbacks, + Logger::Loggable { +public: + ConnectionManager(Config& config); + ~ConnectionManager(); + + // Network::ReadFilter + Network::FilterStatus onData(Buffer::Instance& data, bool end_stream) override; + Network::FilterStatus onNewConnection() override { return Network::FilterStatus::Continue; } + void initializeReadFilterCallbacks(Network::ReadFilterCallbacks&) override; + + // Network::ConnectionCallbacks + void onEvent(Network::ConnectionEvent) override; + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + + // DecoderCallbacks + ThriftFilters::DecoderFilter& newDecoderFilter() override; + +private: + class Message { + public: + Message(const std::string& method_name, MessageType msg_type, int32_t seq_id) + : method_name_(method_name), msg_type_(msg_type), seq_id_(seq_id) {} + + const std::string method_name_; + const MessageType msg_type_; + const int32_t seq_id_; + absl::optional success_; + }; + + struct ActiveRpc; + + struct ResponseDecoder : public DecoderCallbacks, public ProtocolConverter { + ResponseDecoder(ActiveRpc& parent, TransportType transport_type, ProtocolType protocol_type) + : parent_(parent), + decoder_(std::make_unique( + NamedTransportConfigFactory::getFactory(transport_type).createTransport(), + NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol(), *this)), + complete_(false), first_reply_field_(false) { + // Use the factory to get the concrete protocol from the decoder protocol (as opposed to + // potentially pre-detection auto protocol). + initProtocolConverter( + NamedProtocolConfigFactory::getFactory(parent_.parent_.decoder_->protocolType()) + .createProtocol(), + parent_.response_buffer_); + } + + bool onData(Buffer::Instance& data); + + // ProtocolConverter + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override; + ThriftFilters::FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) override; + ThriftFilters::FilterStatus transportBegin(absl::optional size) override { + UNREFERENCED_PARAMETER(size); + return ThriftFilters::FilterStatus::Continue; + } + ThriftFilters::FilterStatus transportEnd() override; + + // DecoderCallbacks + ThriftFilters::DecoderFilter& newDecoderFilter() override { return *this; } + + ActiveRpc& parent_; + DecoderPtr decoder_; + Buffer::OwnedImpl upstream_buffer_; + absl::optional reply_; + bool complete_ : 1; + bool first_reply_field_ : 1; + }; + typedef std::unique_ptr ResponseDecoderPtr; + + // ActiveRpc tracks request/response pairs. + struct ActiveRpc : LinkedObject, + public Event::DeferredDeletable, + public ThriftFilters::DecoderFilter, + public ThriftFilters::DecoderFilterCallbacks, + public ThriftFilters::FilterChainFactoryCallbacks { + ActiveRpc(ConnectionManager& parent) + : parent_(parent), request_timer_(new Stats::Timespan(parent_.stats_.request_time_ms_)), + stream_id_(parent_.stream_id_++) { + parent_.stats_.request_active_.inc(); + } + ~ActiveRpc() { + request_timer_->complete(); + parent_.stats_.request_active_.dec(); + + if (decoder_filter_ != nullptr) { + decoder_filter_->onDestroy(); + } + } + + // ThriftFilters::DecoderFilter + void onDestroy() override { NOT_IMPLEMENTED; } + void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks&) override { + NOT_IMPLEMENTED; + } + void resetUpstreamConnection() override { NOT_IMPLEMENTED; } + ThriftFilters::FilterStatus transportBegin(absl::optional size) override { + return decoder_filter_->transportBegin(size); + } + ThriftFilters::FilterStatus transportEnd() override; + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override { + call_.emplace(std::string(name), msg_type, seq_id); + return decoder_filter_->messageBegin(name, msg_type, seq_id); + } + ThriftFilters::FilterStatus messageEnd() override { return decoder_filter_->messageEnd(); } + ThriftFilters::FilterStatus structBegin(absl::string_view name) override { + return decoder_filter_->structBegin(name); + } + ThriftFilters::FilterStatus structEnd() override { return decoder_filter_->structEnd(); } + ThriftFilters::FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) override { + return decoder_filter_->fieldBegin(name, field_type, field_id); + } + ThriftFilters::FilterStatus fieldEnd() override { return decoder_filter_->fieldEnd(); } + ThriftFilters::FilterStatus boolValue(bool value) override { + return decoder_filter_->boolValue(value); + } + ThriftFilters::FilterStatus byteValue(uint8_t value) override { + return decoder_filter_->byteValue(value); + } + ThriftFilters::FilterStatus int16Value(int16_t value) override { + return decoder_filter_->int16Value(value); + } + ThriftFilters::FilterStatus int32Value(int32_t value) override { + return decoder_filter_->int32Value(value); + } + ThriftFilters::FilterStatus int64Value(int64_t value) override { + return decoder_filter_->int64Value(value); + } + ThriftFilters::FilterStatus doubleValue(double value) override { + return decoder_filter_->doubleValue(value); + } + ThriftFilters::FilterStatus stringValue(absl::string_view value) override { + return decoder_filter_->stringValue(value); + } + ThriftFilters::FilterStatus mapBegin(FieldType key_type, FieldType value_type, + uint32_t size) override { + return decoder_filter_->mapBegin(key_type, value_type, size); + } + ThriftFilters::FilterStatus mapEnd() override { return decoder_filter_->mapEnd(); } + ThriftFilters::FilterStatus listBegin(FieldType elem_type, uint32_t size) override { + return decoder_filter_->listBegin(elem_type, size); + } + ThriftFilters::FilterStatus listEnd() override { return decoder_filter_->listEnd(); } + ThriftFilters::FilterStatus setBegin(FieldType elem_type, uint32_t size) override { + return decoder_filter_->setBegin(elem_type, size); + } + ThriftFilters::FilterStatus setEnd() override { return decoder_filter_->setEnd(); } + + // ThriftFilters::DecoderFilterCallbacks + uint64_t streamId() override { return stream_id_; } + const Network::Connection* connection() override; + void continueDecoding() override; + Router::RouteConstSharedPtr route() override; + TransportType downstreamTransportType() override { return parent_.decoder_->transportType(); } + ProtocolType downstreamProtocolType() override { return parent_.decoder_->protocolType(); } + void sendLocalReply(ThriftFilters::DirectResponsePtr&& response) override; + void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) override; + bool upstreamData(Buffer::Instance& buffer) override; + void resetDownstreamConnection() override; + + // Thrift::FilterChainFactoryCallbacks + void addDecoderFilter(ThriftFilters::DecoderFilterSharedPtr filter) override { + // TODO(zuercher): support multiple filters + filter->setDecoderFilterCallbacks(*this); + decoder_filter_ = filter; + } + + void createFilterChain(); + void onReset(); + void onError(const std::string& what); + + ConnectionManager& parent_; + Stats::TimespanPtr request_timer_; + uint64_t stream_id_; + ThriftFilters::DecoderFilterSharedPtr decoder_filter_; + ResponseDecoderPtr response_decoder_; + absl::optional cached_route_; + absl::optional call_; + Buffer::OwnedImpl response_buffer_; + }; + + typedef std::unique_ptr ActiveRpcPtr; + + void continueDecoding(); + void dispatch(); + void doDeferredRpcDestroy(ActiveRpc& rpc); + void resetAllRpcs(); + + Config& config_; + ThriftFilterStats& stats_; + + Network::ReadFilterCallbacks* read_callbacks_{}; + + DecoderPtr decoder_; + std::list rpcs_; + Buffer::OwnedImpl request_buffer_; + uint64_t stream_id_{1}; + bool stopped_{false}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index d2cc0d898dcef..815a379edc9fd 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -13,69 +13,72 @@ namespace NetworkFilters { namespace ThriftProxy { // MessageBegin -> StructBegin -ProtocolState DecoderStateMachine::messageBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::messageBegin(Buffer::Instance& buffer) { std::string message_name; MessageType msg_type; int32_t seq_id; if (!proto_.readMessageBegin(buffer, message_name, msg_type, seq_id)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.clear(); stack_.emplace_back(Frame(ProtocolState::MessageEnd)); - return ProtocolState::StructBegin; + return DecoderStatus(ProtocolState::StructBegin, + filter_.messageBegin(absl::string_view(message_name), msg_type, seq_id)); } // MessageEnd -> Done -ProtocolState DecoderStateMachine::messageEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::messageEnd(Buffer::Instance& buffer) { if (!proto_.readMessageEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return ProtocolState::Done; + return DecoderStatus(ProtocolState::Done, filter_.messageEnd()); } // StructBegin -> FieldBegin -ProtocolState DecoderStateMachine::structBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::structBegin(Buffer::Instance& buffer) { std::string name; if (!proto_.readStructBegin(buffer, name)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return ProtocolState::FieldBegin; + return DecoderStatus(ProtocolState::FieldBegin, filter_.structBegin(absl::string_view(name))); } // StructEnd -> stack's return state -ProtocolState DecoderStateMachine::structEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::structEnd(Buffer::Instance& buffer) { if (!proto_.readStructEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.structEnd()); } // FieldBegin -> FieldValue, or // FieldBegin -> StructEnd (stop field) -ProtocolState DecoderStateMachine::fieldBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::fieldBegin(Buffer::Instance& buffer) { std::string name; FieldType field_type; int16_t field_id; if (!proto_.readFieldBegin(buffer, name, field_type, field_id)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } if (field_type == FieldType::Stop) { - return ProtocolState::StructEnd; + return DecoderStatus(ProtocolState::StructEnd, ThriftFilters::FilterStatus::Continue); } stack_.emplace_back(Frame(ProtocolState::FieldEnd, field_type)); - return ProtocolState::FieldValue; + return DecoderStatus(ProtocolState::FieldValue, + filter_.fieldBegin(absl::string_view(name), field_type, field_id)); } // FieldValue -> FieldEnd (via stack return state) -ProtocolState DecoderStateMachine::fieldValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::fieldValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); @@ -83,36 +86,36 @@ ProtocolState DecoderStateMachine::fieldValue(Buffer::Instance& buffer) { } // FieldEnd -> FieldBegin -ProtocolState DecoderStateMachine::fieldEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::fieldEnd(Buffer::Instance& buffer) { if (!proto_.readFieldEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } popReturnState(); - return ProtocolState::FieldBegin; + return DecoderStatus(ProtocolState::FieldBegin, filter_.fieldEnd()); } // ListBegin -> ListValue -ProtocolState DecoderStateMachine::listBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::listBegin(Buffer::Instance& buffer) { FieldType elem_type; uint32_t size; if (!proto_.readListBegin(buffer, elem_type, size)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.emplace_back(Frame(ProtocolState::ListEnd, elem_type, size)); - return ProtocolState::ListValue; + return DecoderStatus(ProtocolState::ListValue, filter_.listBegin(elem_type, size)); } // ListValue -> ListValue, ListBegin, MapBegin, SetBegin, StructBegin (depending on value type), or // ListValue -> ListEnd -ProtocolState DecoderStateMachine::listValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::listValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); if (frame.remaining_ == 0) { - return popReturnState(); + return DecoderStatus(popReturnState(), ThriftFilters::FilterStatus::Continue); } frame.remaining_--; @@ -120,34 +123,35 @@ ProtocolState DecoderStateMachine::listValue(Buffer::Instance& buffer) { } // ListEnd -> stack's return state -ProtocolState DecoderStateMachine::listEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::listEnd(Buffer::Instance& buffer) { if (!proto_.readListEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.listEnd()); } // MapBegin -> MapKey -ProtocolState DecoderStateMachine::mapBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapBegin(Buffer::Instance& buffer) { FieldType key_type, value_type; uint32_t size; if (!proto_.readMapBegin(buffer, key_type, value_type, size)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.emplace_back(Frame(ProtocolState::MapEnd, key_type, value_type, size)); - return ProtocolState::MapKey; + return DecoderStatus(ProtocolState::MapKey, filter_.mapBegin(key_type, value_type, size)); } // MapKey -> MapValue, ListBegin, MapBegin, SetBegin, StructBegin (depending on key type), or // MapKey -> MapEnd -ProtocolState DecoderStateMachine::mapKey(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapKey(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); if (frame.remaining_ == 0) { - return popReturnState(); + return DecoderStatus(popReturnState(), ThriftFilters::FilterStatus::Continue); } return handleValue(buffer, frame.elem_type_, ProtocolState::MapValue); @@ -155,7 +159,7 @@ ProtocolState DecoderStateMachine::mapKey(Buffer::Instance& buffer) { // MapValue -> MapKey, ListBegin, MapBegin, SetBegin, StructBegin (depending on value type), or // MapValue -> MapKey -ProtocolState DecoderStateMachine::mapValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); ASSERT(frame.remaining_ != 0); @@ -165,34 +169,35 @@ ProtocolState DecoderStateMachine::mapValue(Buffer::Instance& buffer) { } // MapEnd -> stack's return state -ProtocolState DecoderStateMachine::mapEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapEnd(Buffer::Instance& buffer) { if (!proto_.readMapEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.mapEnd()); } // SetBegin -> SetValue -ProtocolState DecoderStateMachine::setBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::setBegin(Buffer::Instance& buffer) { FieldType elem_type; uint32_t size; if (!proto_.readSetBegin(buffer, elem_type, size)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.emplace_back(Frame(ProtocolState::SetEnd, elem_type, size)); - return ProtocolState::SetValue; + return DecoderStatus(ProtocolState::SetValue, filter_.setBegin(elem_type, size)); } // SetValue -> SetValue, ListBegin, MapBegin, SetBegin, StructBegin (depending on value type), or // SetValue -> SetEnd -ProtocolState DecoderStateMachine::setValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::setValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); if (frame.remaining_ == 0) { - return popReturnState(); + return DecoderStatus(popReturnState(), ThriftFilters::FilterStatus::Continue); } frame.remaining_--; @@ -200,85 +205,88 @@ ProtocolState DecoderStateMachine::setValue(Buffer::Instance& buffer) { } // SetEnd -> stack's return state -ProtocolState DecoderStateMachine::setEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::setEnd(Buffer::Instance& buffer) { if (!proto_.readSetEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.setEnd()); } -ProtocolState DecoderStateMachine::handleValue(Buffer::Instance& buffer, FieldType elem_type, - ProtocolState return_state) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::handleValue(Buffer::Instance& buffer, + FieldType elem_type, + ProtocolState return_state) { switch (elem_type) { - case FieldType::Bool: + case FieldType::Bool: { bool value; - if (!proto_.readBool(buffer, value)) { - return ProtocolState::WaitForData; + if (proto_.readBool(buffer, value)) { + return DecoderStatus(return_state, filter_.boolValue(value)); } break; + } case FieldType::Byte: { uint8_t value; - if (!proto_.readByte(buffer, value)) { - return ProtocolState::WaitForData; + if (proto_.readByte(buffer, value)) { + return DecoderStatus(return_state, filter_.byteValue(value)); } break; } case FieldType::I16: { int16_t value; - if (!proto_.readInt16(buffer, value)) { - return ProtocolState::WaitForData; + if (proto_.readInt16(buffer, value)) { + return DecoderStatus(return_state, filter_.int16Value(value)); } break; } case FieldType::I32: { int32_t value; - if (!proto_.readInt32(buffer, value)) { - return ProtocolState::WaitForData; + if (proto_.readInt32(buffer, value)) { + return DecoderStatus(return_state, filter_.int32Value(value)); } break; } case FieldType::I64: { int64_t value; - if (!proto_.readInt64(buffer, value)) { - return ProtocolState::WaitForData; + if (proto_.readInt64(buffer, value)) { + return DecoderStatus(return_state, filter_.int64Value(value)); } break; } case FieldType::Double: { double value; - if (!proto_.readDouble(buffer, value)) { - return ProtocolState::WaitForData; + if (proto_.readDouble(buffer, value)) { + return DecoderStatus(return_state, filter_.doubleValue(value)); } break; } case FieldType::String: { std::string value; - if (!proto_.readString(buffer, value)) { - return ProtocolState::WaitForData; + if (proto_.readString(buffer, value)) { + return DecoderStatus(return_state, filter_.stringValue(value)); } break; } case FieldType::Struct: stack_.emplace_back(Frame(return_state)); - return ProtocolState::StructBegin; + return DecoderStatus(ProtocolState::StructBegin, ThriftFilters::FilterStatus::Continue); case FieldType::Map: stack_.emplace_back(Frame(return_state)); - return ProtocolState::MapBegin; + return DecoderStatus(ProtocolState::MapBegin, ThriftFilters::FilterStatus::Continue); case FieldType::List: stack_.emplace_back(Frame(return_state)); - return ProtocolState::ListBegin; + return DecoderStatus(ProtocolState::ListBegin, ThriftFilters::FilterStatus::Continue); case FieldType::Set: stack_.emplace_back(Frame(return_state)); - return ProtocolState::SetBegin; + return DecoderStatus(ProtocolState::SetBegin, ThriftFilters::FilterStatus::Continue); default: throw EnvoyException(fmt::format("unknown field type {}", static_cast(elem_type))); } - return return_state; + return DecoderStatus(ProtocolState::WaitForData); } -ProtocolState DecoderStateMachine::handleState(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::handleState(Buffer::Instance& buffer) { switch (state_) { case ProtocolState::MessageBegin: return messageBegin(buffer); @@ -328,61 +336,97 @@ ProtocolState DecoderStateMachine::popReturnState() { ProtocolState DecoderStateMachine::run(Buffer::Instance& buffer) { while (state_ != ProtocolState::Done) { - ProtocolState s = handleState(buffer); - if (s == ProtocolState::WaitForData) { - return s; + DecoderStatus s = handleState(buffer); + if (s.next_state_ == ProtocolState::WaitForData) { + return ProtocolState::WaitForData; } - state_ = s; + state_ = s.next_state_; + + ASSERT(s.filter_status_.has_value()); + if (s.filter_status_.value() == ThriftFilters::FilterStatus::StopIteration) { + return ProtocolState::StopIteration; + } } return state_; } -Decoder::Decoder(TransportPtr&& transport, ProtocolPtr&& protocol) - : transport_(std::move(transport)), protocol_(std::move(protocol)), state_machine_{}, - frame_started_(false) {} +Decoder::Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks) + : transport_(std::move(transport)), protocol_(std::move(protocol)), callbacks_(callbacks) {} + +void Decoder::complete() { + request_.reset(); + state_machine_ = nullptr; + frame_started_ = false; + frame_ended_ = false; +} -void Decoder::onData(Buffer::Instance& data) { +ThriftFilters::FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { ENVOY_LOG(debug, "thrift: {} bytes available", data.length()); + buffer_underflow = false; - while (true) { - if (!frame_started_) { - // Look for start of next frame. - if (!transport_->decodeFrameStart(data)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_->name()); - return; - } - ENVOY_LOG(debug, "thrift: {} transport started", transport_->name()); - - frame_started_ = true; - state_machine_ = std::make_unique(*protocol_); - } + if (frame_ended_) { + // Continuation after filter stopped iteration on transportComplete callback. + complete(); + buffer_underflow = (data.length() == 0); + return ThriftFilters::FilterStatus::Continue; + } - ASSERT(state_machine_ != nullptr); + if (!frame_started_) { + // Look for start of next frame. + absl::optional size{}; + if (!transport_->decodeFrameStart(data, size)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_->name()); + buffer_underflow = true; + return ThriftFilters::FilterStatus::Continue; + } + ENVOY_LOG(debug, "thrift: {} transport started", transport_->name()); - ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_->name(), - ProtocolStateNameValues::name(state_machine_->currentState()), data.length()); + request_ = std::make_unique(callbacks_.newDecoderFilter()); + frame_started_ = true; + state_machine_ = std::make_unique(*protocol_, request_->filter_); - ProtocolState rv = state_machine_->run(data); - if (rv == ProtocolState::WaitForData) { - ENVOY_LOG(debug, "thrift: wait for data"); - return; + if (request_->filter_.transportBegin(size) == ThriftFilters::FilterStatus::StopIteration) { + return ThriftFilters::FilterStatus::StopIteration; } + } - ASSERT(rv == ProtocolState::Done); + ASSERT(state_machine_ != nullptr); - // Message complete, get decode end of frame. - if (!transport_->decodeFrameEnd(data)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_->name()); - return; - } - ENVOY_LOG(debug, "thrift: {} transport ended", transport_->name()); + ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_->name(), + ProtocolStateNameValues::name(state_machine_->currentState()), data.length()); - // Reset for next frame. - state_machine_ = nullptr; - frame_started_ = false; + ProtocolState rv = state_machine_->run(data); + if (rv == ProtocolState::WaitForData) { + ENVOY_LOG(debug, "thrift: wait for data"); + buffer_underflow = true; + return ThriftFilters::FilterStatus::Continue; + } else if (rv == ProtocolState::StopIteration) { + ENVOY_LOG(debug, "thrift: wait for continuation"); + return ThriftFilters::FilterStatus::StopIteration; } + + ASSERT(rv == ProtocolState::Done); + + // Message complete, decode end of frame. + if (!transport_->decodeFrameEnd(data)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_->name()); + buffer_underflow = true; + return ThriftFilters::FilterStatus::Continue; + } + + frame_ended_ = true; + + ENVOY_LOG(debug, "thrift: {} transport ended", transport_->name()); + if (request_->filter_.transportEnd() == ThriftFilters::FilterStatus::StopIteration) { + return ThriftFilters::FilterStatus::StopIteration; + } + + // Reset for next frame. + complete(); + buffer_underflow = (data.length() == 0); + return ThriftFilters::FilterStatus::Continue; } } // namespace ThriftProxy diff --git a/source/extensions/filters/network/thrift_proxy/decoder.h b/source/extensions/filters/network/thrift_proxy/decoder.h index a05d70cee30c8..26068858beb6d 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.h +++ b/source/extensions/filters/network/thrift_proxy/decoder.h @@ -6,6 +6,7 @@ #include "common/common/assert.h" #include "common/common/logger.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/protocol.h" #include "extensions/filters/network/thrift_proxy/transport.h" @@ -15,6 +16,7 @@ namespace NetworkFilters { namespace ThriftProxy { #define ALL_PROTOCOL_STATES(FUNCTION) \ + FUNCTION(StopIteration) \ FUNCTION(WaitForData) \ FUNCTION(MessageBegin) \ FUNCTION(MessageEnd) \ @@ -61,7 +63,8 @@ class ProtocolStateNameValues { */ class DecoderStateMachine { public: - DecoderStateMachine(Protocol& proto) : proto_(proto), state_(ProtocolState::MessageBegin) {} + DecoderStateMachine(Protocol& proto, ThriftFilters::DecoderFilter& filter) + : proto_(proto), filter_(filter), state_(ProtocolState::MessageBegin) {} /** * Consumes as much data from the configured Buffer as possible and executes the decoding state @@ -114,72 +117,107 @@ class DecoderStateMachine { uint32_t remaining_; }; + struct DecoderStatus { + DecoderStatus(ProtocolState next_state) : next_state_(next_state), filter_status_{} {}; + DecoderStatus(ProtocolState next_state, ThriftFilters::FilterStatus filter_status) + : next_state_(next_state), filter_status_(filter_status){}; + + ProtocolState next_state_; + absl::optional filter_status_; + }; + // These functions map directly to the matching ProtocolState values. Each returns the next state // or ProtocolState::WaitForData if more data is required. - ProtocolState messageBegin(Buffer::Instance& buffer); - ProtocolState messageEnd(Buffer::Instance& buffer); - ProtocolState structBegin(Buffer::Instance& buffer); - ProtocolState structEnd(Buffer::Instance& buffer); - ProtocolState fieldBegin(Buffer::Instance& buffer); - ProtocolState fieldValue(Buffer::Instance& buffer); - ProtocolState fieldEnd(Buffer::Instance& buffer); - ProtocolState listBegin(Buffer::Instance& buffer); - ProtocolState listValue(Buffer::Instance& buffer); - ProtocolState listEnd(Buffer::Instance& buffer); - ProtocolState mapBegin(Buffer::Instance& buffer); - ProtocolState mapKey(Buffer::Instance& buffer); - ProtocolState mapValue(Buffer::Instance& buffer); - ProtocolState mapEnd(Buffer::Instance& buffer); - ProtocolState setBegin(Buffer::Instance& buffer); - ProtocolState setValue(Buffer::Instance& buffer); - ProtocolState setEnd(Buffer::Instance& buffer); + DecoderStatus messageBegin(Buffer::Instance& buffer); + DecoderStatus messageEnd(Buffer::Instance& buffer); + DecoderStatus structBegin(Buffer::Instance& buffer); + DecoderStatus structEnd(Buffer::Instance& buffer); + DecoderStatus fieldBegin(Buffer::Instance& buffer); + DecoderStatus fieldValue(Buffer::Instance& buffer); + DecoderStatus fieldEnd(Buffer::Instance& buffer); + DecoderStatus listBegin(Buffer::Instance& buffer); + DecoderStatus listValue(Buffer::Instance& buffer); + DecoderStatus listEnd(Buffer::Instance& buffer); + DecoderStatus mapBegin(Buffer::Instance& buffer); + DecoderStatus mapKey(Buffer::Instance& buffer); + DecoderStatus mapValue(Buffer::Instance& buffer); + DecoderStatus mapEnd(Buffer::Instance& buffer); + DecoderStatus setBegin(Buffer::Instance& buffer); + DecoderStatus setValue(Buffer::Instance& buffer); + DecoderStatus setEnd(Buffer::Instance& buffer); // handleValue represents the generic Value state from the state machine documentation. It // returns either ProtocolState::WaitForData if more data is required or the next state. For // structs, lists, maps, or sets the return_state is pushed onto the stack and the next state is // based on elem_type. For primitive value types, return_state is returned as the next state // (unless WaitForData is returned). - ProtocolState handleValue(Buffer::Instance& buffer, FieldType elem_type, + DecoderStatus handleValue(Buffer::Instance& buffer, FieldType elem_type, ProtocolState return_state); // handleState delegates to the appropriate method based on state_. - ProtocolState handleState(Buffer::Instance& buffer); + DecoderStatus handleState(Buffer::Instance& buffer); // Helper method to retrieve the current frame's return state and remove the frame from the // stack. ProtocolState popReturnState(); Protocol& proto_; + ThriftFilters::DecoderFilter& filter_; ProtocolState state_; std::vector stack_; }; typedef std::unique_ptr DecoderStateMachinePtr; +class DecoderCallbacks { +public: + virtual ~DecoderCallbacks() {} + + /** + * @return DecoderFilter& a new DecoderFilter for a message. + */ + virtual ThriftFilters::DecoderFilter& newDecoderFilter() PURE; +}; + /** * Decoder encapsulates a configured TransportPtr and ProtocolPtr. */ class Decoder : public Logger::Loggable { public: - Decoder(TransportPtr&& transport, ProtocolPtr&& protocol); + Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks); + Decoder(TransportType transport_type, ProtocolType protocol_type, DecoderCallbacks& callbacks); /** - * Drains data from the given buffer while executing a DecoderStateMachine over the data. A new - * DecoderStateMachine is instantiated for each message. + * Drains data from the given buffer while executing a DecoderStateMachine over the data. * * @param data a Buffer containing Thrift protocol data + * @param buffer_underflow bool set to true if more data is required to continue decoding + * @return ThriftFilters::FilterStatus::StopIteration when waiting for filter continuation, + * Continue otherwise. * @throw EnvoyException on Thrift protocol errors */ - void onData(Buffer::Instance& data); + ThriftFilters::FilterStatus onData(Buffer::Instance& data, bool& buffer_underflow); - const Transport& transport() { return *transport_; } - const Protocol& protocol() { return *protocol_; } + TransportType transportType() { return transport_->type(); } + ProtocolType protocolType() { return protocol_->type(); } private: + struct ActiveRequest { + ActiveRequest(ThriftFilters::DecoderFilter& filter) : filter_(filter) {} + + ThriftFilters::DecoderFilter& filter_; + }; + typedef std::unique_ptr ActiveRequestPtr; + + void complete(); + TransportPtr transport_; ProtocolPtr protocol_; + DecoderCallbacks& callbacks_; + ActiveRequestPtr request_; DecoderStateMachinePtr state_machine_; - bool frame_started_; + bool frame_started_{false}; + bool frame_ended_{false}; }; typedef std::unique_ptr DecoderPtr; diff --git a/source/extensions/filters/network/thrift_proxy/filter.cc b/source/extensions/filters/network/thrift_proxy/filter.cc deleted file mode 100644 index 3244c153d961c..0000000000000 --- a/source/extensions/filters/network/thrift_proxy/filter.cc +++ /dev/null @@ -1,310 +0,0 @@ -#include "extensions/filters/network/thrift_proxy/filter.h" - -#include "envoy/common/exception.h" - -#include "common/common/assert.h" - -#include "extensions/filters/network/thrift_proxy/buffer_helper.h" -#include "extensions/filters/network/thrift_proxy/protocol_impl.h" -#include "extensions/filters/network/thrift_proxy/transport_impl.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -Filter::Filter(const std::string& stat_prefix, Stats::Scope& scope) - : req_callbacks_(*this), resp_callbacks_(*this), stats_(generateStats(stat_prefix, scope)) {} - -Filter::~Filter() {} - -void Filter::onEvent(Network::ConnectionEvent event) { - if (active_call_map_.empty() && req_ == nullptr && resp_ == nullptr) { - return; - } - - if (event == Network::ConnectionEvent::RemoteClose) { - stats_.cx_destroy_local_with_active_rq_.inc(); - } - - if (event == Network::ConnectionEvent::LocalClose) { - stats_.cx_destroy_remote_with_active_rq_.inc(); - } -} - -Network::FilterStatus Filter::onData(Buffer::Instance& data, bool) { - if (!sniffing_) { - if (req_buffer_.length() > 0) { - // Stopped sniffing during response (in onWrite). Make sure leftover req_buffer_ contents are - // at the start of data or the upstream will see a corrupted request. - req_buffer_.move(data); - data.move(req_buffer_); - ASSERT(req_buffer_.length() == 0); - } - - return Network::FilterStatus::Continue; - } - - if (req_decoder_ == nullptr) { - req_decoder_ = std::make_unique(std::make_unique(req_callbacks_), - std::make_unique(req_callbacks_)); - } - - ENVOY_LOG(trace, "thrift: read {} bytes", data.length()); - req_buffer_.move(data); - - try { - BufferWrapper wrapped(req_buffer_); - - req_decoder_->onData(wrapped); - - // Move consumed portion of request back to data for the upstream to consume. - uint64_t pos = wrapped.position(); - if (pos > 0) { - data.move(req_buffer_, pos); - } - } catch (const EnvoyException& ex) { - ENVOY_LOG(error, "thrift error: {}", ex.what()); - req_decoder_.reset(); - data.move(req_buffer_); - stats_.request_decoding_error_.inc(); - sniffing_ = false; - } - - return Network::FilterStatus::Continue; -} - -Network::FilterStatus Filter::onWrite(Buffer::Instance& data, bool) { - if (!sniffing_) { - if (resp_buffer_.length() > 0) { - // Stopped sniffing during request (in onData). Make sure resp_buffer_ contents are at the - // start of data or the downstream will see a corrupted response. - resp_buffer_.move(data); - data.move(resp_buffer_); - ASSERT(resp_buffer_.length() == 0); - } - - return Network::FilterStatus::Continue; - } - - if (resp_decoder_ == nullptr) { - resp_decoder_ = std::make_unique(std::make_unique(resp_callbacks_), - std::make_unique(resp_callbacks_)); - } - - ENVOY_LOG(trace, "thrift wrote {} bytes", data.length()); - resp_buffer_.move(data); - - try { - BufferWrapper wrapped(resp_buffer_); - - resp_decoder_->onData(wrapped); - - // Move consumed portion of response back to data for the downstream to consume. - uint64_t pos = wrapped.position(); - if (pos > 0) { - data.move(resp_buffer_, pos); - } - } catch (const EnvoyException& ex) { - ENVOY_LOG(error, "thrift error: {}", ex.what()); - resp_decoder_.reset(); - data.move(resp_buffer_); - - stats_.response_decoding_error_.inc(); - sniffing_ = false; - } - - return Network::FilterStatus::Continue; -} - -void Filter::chargeDownstreamRequestStart(MessageType msg_type, int32_t seq_id) { - if (req_ != nullptr) { - throw EnvoyException("unexpected request messageStart callback"); - } - - if (active_call_map_.size() >= 64) { - throw EnvoyException("too many pending calls (64), disabling sniffing"); - } - - req_ = std::make_unique(*this, msg_type, seq_id); - - stats_.request_.inc(); - switch (msg_type) { - case MessageType::Call: - stats_.request_call_.inc(); - break; - case MessageType::Oneway: - stats_.request_oneway_.inc(); - break; - default: - stats_.request_invalid_type_.inc(); - break; - } -} - -void Filter::chargeDownstreamRequestComplete() { - if (req_ == nullptr) { - throw EnvoyException("unexpected request messageComplete callback"); - } - - // One-way messages do not receive responses. - if (req_->msg_type_ == MessageType::Oneway) { - req_.reset(); - return; - } - - int32_t seq_id = req_->seq_id_; - active_call_map_.emplace(seq_id, std::move(req_)); -} - -void Filter::chargeUpstreamResponseStart(MessageType msg_type, int32_t seq_id) { - if (resp_ != nullptr) { - throw EnvoyException("unexpected response messageStart callback"); - } - - auto i = active_call_map_.find(seq_id); - if (i == active_call_map_.end()) { - throw EnvoyException(fmt::format("unknown reply seq_id {}", seq_id)); - } - - resp_ = std::move(i->second); - resp_->response_msg_type_ = msg_type; - active_call_map_.erase(i); -} - -void Filter::chargeUpstreamResponseField(FieldType field_type, int16_t field_id) { - if (resp_ == nullptr) { - throw EnvoyException("unexpected response messageField callback"); - } - - if (resp_->response_msg_type_ != MessageType::Reply) { - // If this is not a reply, we'll count an exception instead of an error, so leave - // resp_->success_ unset. - return; - } - - if (resp_->success_.has_value()) { - // If resp->success_ is already set, leave the existing value. - return; - } - - // Successful replies have a single field, with field_id 0 that contains the response value. - // IDL-level exceptions are encoded as a single field with field_id >= 1. - resp_->success_ = field_id == 0 && field_type != FieldType::Stop; -} - -void Filter::chargeUpstreamResponseComplete() { - if (resp_ == nullptr) { - throw EnvoyException("unexpected response messageComplete callback"); - } - - stats_.response_.inc(); - switch (resp_->response_msg_type_) { - case MessageType::Reply: - stats_.response_reply_.inc(); - break; - case MessageType::Exception: - stats_.response_exception_.inc(); - break; - default: - stats_.response_invalid_type_.inc(); - break; - } - - if (resp_->success_.has_value()) { - if (resp_->success_.value()) { - stats_.response_success_.inc(); - } else { - stats_.response_error_.inc(); - } - } - - resp_.reset(); -} - -void Filter::RequestCallbacks::transportFrameStart(absl::optional size) { - UNREFERENCED_PARAMETER(size); - ENVOY_LOG(debug, "thrift request: started {} frame", parent_.req_decoder_->transport().name()); -} - -void Filter::RequestCallbacks::transportFrameComplete() { - ENVOY_LOG(debug, "thrift request: ended {} frame", parent_.req_decoder_->transport().name()); -} - -void Filter::RequestCallbacks::messageStart(const absl::string_view name, MessageType msg_type, - int32_t seq_id) { - ENVOY_LOG(debug, "thrift request: started {} message {}: {}", - parent_.req_decoder_->protocol().name(), name, seq_id); - parent_.chargeDownstreamRequestStart(msg_type, seq_id); -} - -void Filter::RequestCallbacks::structBegin(const absl::string_view name) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift request: started {} struct", parent_.req_decoder_->protocol().name()); -} - -void Filter::RequestCallbacks::structField(const absl::string_view name, FieldType field_type, - int16_t field_id) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift request: started {} field {}, type {}", - parent_.req_decoder_->protocol().name(), field_id, static_cast(field_type)); -} - -void Filter::RequestCallbacks::structEnd() { - ENVOY_LOG(debug, "thrift request: ended {} struct", parent_.req_decoder_->protocol().name()); -} - -void Filter::RequestCallbacks::messageComplete() { - ENVOY_LOG(debug, "thrift request: ended {} message", parent_.req_decoder_->protocol().name()); - parent_.chargeDownstreamRequestComplete(); -} - -void Filter::ResponseCallbacks::transportFrameStart(absl::optional size) { - UNREFERENCED_PARAMETER(size); - ENVOY_LOG(debug, "thrift response: started {} frame", parent_.resp_decoder_->transport().name()); -} - -void Filter::ResponseCallbacks::transportFrameComplete() { - ENVOY_LOG(debug, "thrift response: ended {} frame", parent_.resp_decoder_->transport().name()); -} - -void Filter::ResponseCallbacks::messageStart(const absl::string_view name, MessageType msg_type, - int32_t seq_id) { - ENVOY_LOG(debug, "thrift response: started {} message {}: {}", - parent_.resp_decoder_->protocol().name(), name, seq_id); - parent_.chargeUpstreamResponseStart(msg_type, seq_id); -} - -void Filter::ResponseCallbacks::structBegin(const absl::string_view name) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift response: started {} struct", parent_.req_decoder_->protocol().name()); - depth_++; -} - -void Filter::ResponseCallbacks::structField(const absl::string_view name, FieldType field_type, - int16_t field_id) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift response: started {} field {}, type {}", - parent_.req_decoder_->protocol().name(), field_id, static_cast(field_type)); - - if (depth_ == 1) { - // Only care about the outermost struct, which corresponds to the success or failure of the - // request. - parent_.chargeUpstreamResponseField(field_type, field_id); - } -} - -void Filter::ResponseCallbacks::structEnd() { - ENVOY_LOG(debug, "thrift request: ended {} struct", parent_.req_decoder_->protocol().name()); - depth_--; -} - -void Filter::ResponseCallbacks::messageComplete() { - ENVOY_LOG(debug, "thrift response: ended {} message", parent_.resp_decoder_->protocol().name()); - parent_.chargeUpstreamResponseComplete(); -} - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filter.h b/source/extensions/filters/network/thrift_proxy/filter.h deleted file mode 100644 index a0f028ed0221d..0000000000000 --- a/source/extensions/filters/network/thrift_proxy/filter.h +++ /dev/null @@ -1,173 +0,0 @@ -#pragma once - -#include - -#include "envoy/network/connection.h" -#include "envoy/network/filter.h" -#include "envoy/stats/stats.h" -#include "envoy/stats/stats_macros.h" -#include "envoy/stats/timespan.h" - -#include "common/buffer/buffer_impl.h" -#include "common/common/logger.h" - -#include "extensions/filters/network/thrift_proxy/decoder.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -/** - * All thrift filter stats. @see stats_macros.h - */ -// clang-format off -#define ALL_THRIFT_FILTER_STATS(COUNTER, GAUGE, HISTOGRAM) \ - COUNTER(request) \ - COUNTER(request_call) \ - COUNTER(request_oneway) \ - COUNTER(request_invalid_type) \ - GAUGE(request_active) \ - COUNTER(request_decoding_error) \ - HISTOGRAM(request_time_ms) \ - COUNTER(response) \ - COUNTER(response_reply) \ - COUNTER(response_success) \ - COUNTER(response_error) \ - COUNTER(response_exception) \ - COUNTER(response_invalid_type) \ - COUNTER(response_decoding_error) \ - COUNTER(cx_destroy_local_with_active_rq) \ - COUNTER(cx_destroy_remote_with_active_rq) -// clang-format on - -/** - * Struct definition for all mongo proxy stats. @see stats_macros.h - */ -struct ThriftFilterStats { - ALL_THRIFT_FILTER_STATS(GENERATE_COUNTER_STRUCT, GENERATE_GAUGE_STRUCT, GENERATE_HISTOGRAM_STRUCT) -}; - -/** - * A sniffing filter for thrift traffic. - */ -class Filter : public Network::Filter, - public Network::ConnectionCallbacks, - Logger::Loggable { -public: - Filter(const std::string& stat_prefix, Stats::Scope& scope); - ~Filter(); - - // Network::ReadFilter - Network::FilterStatus onData(Buffer::Instance& data, bool end_stream) override; - Network::FilterStatus onNewConnection() override { return Network::FilterStatus::Continue; } - void initializeReadFilterCallbacks(Network::ReadFilterCallbacks&) override {} - - // Network::WriteFilter - Network::FilterStatus onWrite(Buffer::Instance& data, bool end_stream) override; - - // Network::ConnectionCallbacks - void onEvent(Network::ConnectionEvent) override; - void onAboveWriteBufferHighWatermark() override {} - void onBelowWriteBufferLowWatermark() override {} - -private: - // RequestCallbacks handles callbacks related to decoding downstream requests. - class RequestCallbacks : public virtual ProtocolCallbacks, public virtual TransportCallbacks { - public: - RequestCallbacks(Filter& parent) : parent_(parent) {} - - // TransportCallbacks - void transportFrameStart(absl::optional size) override; - void transportFrameComplete() override; - - // ProtocolCallbacks - void messageStart(const absl::string_view name, MessageType msg_type, int32_t seq_id) override; - void structBegin(const absl::string_view name) override; - void structField(const absl::string_view name, FieldType field_type, int16_t field_id) override; - void structEnd() override; - void messageComplete() override; - - private: - Filter& parent_; - }; - - // ResponseCallbacks handles callbacks related to decoding upstream responses. - class ResponseCallbacks : public virtual ProtocolCallbacks, public virtual TransportCallbacks { - public: - ResponseCallbacks(Filter& parent) : parent_(parent) {} - - // TransportCallbacks - void transportFrameStart(absl::optional size) override; - void transportFrameComplete() override; - - // ProtocolCallbacks - void messageStart(const absl::string_view name, MessageType msg_type, int32_t seq_id) override; - void structBegin(const absl::string_view name) override; - void structField(const absl::string_view name, FieldType field_type, int16_t field_id) override; - void structEnd() override; - void messageComplete() override; - - private: - Filter& parent_; - int depth_{0}; - }; - - // ActiveMessage tracks downstream requests for which no response has been received. - struct ActiveMessage { - ActiveMessage(Filter& parent, MessageType msg_type, int32_t seq_id) - : parent_(parent), request_timer_(new Stats::Timespan(parent_.stats_.request_time_ms_)), - msg_type_(msg_type), seq_id_(seq_id) { - parent_.stats_.request_active_.inc(); - } - ~ActiveMessage() { - request_timer_->complete(); - parent_.stats_.request_active_.dec(); - } - - Filter& parent_; - Stats::TimespanPtr request_timer_; - const MessageType msg_type_; - const int32_t seq_id_; - MessageType response_msg_type_{}; - absl::optional success_{}; - }; - typedef std::unique_ptr ActiveMessagePtr; - - ThriftFilterStats generateStats(const std::string& prefix, Stats::Scope& scope) { - return ThriftFilterStats{ALL_THRIFT_FILTER_STATS(POOL_COUNTER_PREFIX(scope, prefix), - POOL_GAUGE_PREFIX(scope, prefix), - POOL_HISTOGRAM_PREFIX(scope, prefix))}; - } - - void chargeDownstreamRequestStart(MessageType msg_type, int32_t seq_id); - void chargeDownstreamRequestComplete(); - void chargeUpstreamResponseStart(MessageType msg_type, int32_t seq_id); - void chargeUpstreamResponseField(FieldType field_type, int16_t field_id); - void chargeUpstreamResponseComplete(); - - // Downstream request decoder, callbacks, and buffer. - DecoderPtr req_decoder_{}; - RequestCallbacks req_callbacks_; - Buffer::OwnedImpl req_buffer_; - // Request currently being decoded. - ActiveMessagePtr req_; - - // Upstream response decoder, callbacks, and buffer. - DecoderPtr resp_decoder_{}; - ResponseCallbacks resp_callbacks_; - Buffer::OwnedImpl resp_buffer_; - // Response currently being decoded. - ActiveMessagePtr resp_; - - // List of active request messages. - std::unordered_map active_call_map_; - - bool sniffing_{true}; - ThriftFilterStats stats_; -}; - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/BUILD b/source/extensions/filters/network/thrift_proxy/filters/BUILD new file mode 100644 index 0000000000000..7f374b1fe5fd1 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/BUILD @@ -0,0 +1,50 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "filter_config_interface", + hdrs = ["filter_config.h"], + deps = [ + ":filter_interface", + "//include/envoy/server:filter_config_interface", + "//source/common/common:macros", + "//source/common/protobuf:cc_wkt_protos", + ], +) + +envoy_cc_library( + name = "factory_base_lib", + hdrs = ["factory_base.h"], + deps = [ + ":filter_config_interface", + "//source/common/protobuf:utility_lib", + ], +) + +envoy_cc_library( + name = "filter_interface", + hdrs = ["filter.h"], + external_deps = ["abseil_optional"], + deps = [ + "//include/envoy/buffer:buffer_interface", + "//include/envoy/network:connection_interface", + "//source/extensions/filters/network/thrift_proxy:protocol_interface", + "//source/extensions/filters/network/thrift_proxy:transport_interface", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + ], +) + +envoy_cc_library( + name = "well_known_names", + hdrs = ["well_known_names.h"], + deps = [ + "//source/common/singleton:const_singleton", + ], +) diff --git a/source/extensions/filters/network/thrift_proxy/filters/factory_base.h b/source/extensions/filters/network/thrift_proxy/filters/factory_base.h new file mode 100644 index 0000000000000..bf2bb292f043b --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/factory_base.h @@ -0,0 +1,45 @@ +#pragma once + +#include "common/protobuf/utility.h" + +#include "extensions/filters/network/thrift_proxy/filters/filter_config.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +template class FactoryBase : public NamedThriftFilterConfigFactory { +public: + FilterFactoryCb + createFilterFactoryFromProto(const Protobuf::Message& proto_config, + const std::string& stats_prefix, + Server::Configuration::FactoryContext& context) override { + return createFilterFactoryFromProtoTyped( + MessageUtil::downcastAndValidate(proto_config), stats_prefix, context); + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + + std::string name() override { return name_; } + +protected: + FactoryBase(const std::string& name) : name_(name) {} + +private: + virtual FilterFactoryCb + createFilterFactoryFromProtoTyped(const ConfigProto& proto_config, + const std::string& stats_prefix, + Server::Configuration::FactoryContext& context) PURE; + + const std::string name_; +}; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h new file mode 100644 index 0000000000000..d677eaa0711eb --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -0,0 +1,306 @@ +#pragma once + +#include +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/network/connection.h" + +#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/router/router.h" +#include "extensions/filters/network/thrift_proxy/transport.h" + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +class DirectResponse { +public: + virtual ~DirectResponse() {} + + /** + * Encodes the response via the given Protocol. + * @param proto the Protocol to be used for message encoding + * @param buffer the Buffer into which the message should be encoded + */ + virtual void encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) PURE; +}; + +typedef std::unique_ptr DirectResponsePtr; + +/** + * Decoder filter callbacks add additional callbacks. + */ +class DecoderFilterCallbacks { +public: + virtual ~DecoderFilterCallbacks() {} + + /** + * @return uint64_t the ID of the originating stream for logging purposes. + */ + virtual uint64_t streamId() PURE; + + /** + * @return const Network::Connection* the originating connection, or nullptr if there is none. + */ + virtual const Network::Connection* connection() PURE; + + /** + * Continue iterating through the filter chain with buffered data. This routine can only be + * called if the filter has previously returned StopIteration from one of the DecoderFilter + * methods. The connection manager will callbacks to the next filter in the chain. Further note + * that if the request is not complete, the calling filter may receive further callbacks and must + * return an appropriate status code depending on what the filter needs to do. + */ + virtual void continueDecoding() PURE; + + /** + * @return RouteConstSharedPtr the route for the current request. + */ + virtual Router::RouteConstSharedPtr route() PURE; + + /** + * @return TransportType the originating transport. + */ + virtual TransportType downstreamTransportType() PURE; + + /** + * @return ProtocolType the originating protocol. + */ + virtual ProtocolType downstreamProtocolType() PURE; + + /** + * Create a locally generated response using the provided response object. + * @param response DirectResponsePtr the response to send to the downstream client + */ + virtual void sendLocalReply(DirectResponsePtr&& response) PURE; + + /** + * Indicates the start of an upstream response. May only be called once. + * @param transport_type TransportType the upstream is using + * @param protocol_type ProtocolType the upstream is using + */ + virtual void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) PURE; + + /** + * Called with upstream response data. + * @param data supplies the upstream's data + * @return true if the upstream response is complete; false if more data is expected + */ + virtual bool upstreamData(Buffer::Instance& data) PURE; + + /** + * Reset the downstream connection. + */ + virtual void resetDownstreamConnection() PURE; +}; + +enum class FilterStatus { + // Continue filter chain iteration. + Continue, + + // Stop iterating over filters in the filter chain. Iteration must be explicitly restarted via + // continueDecoding(). + StopIteration +}; + +/** + * Decoder filter interface. + */ +class DecoderFilter { +public: + virtual ~DecoderFilter() {} + + /** + * This routine is called prior to a filter being destroyed. This may happen after normal stream + * finish (both downstream and upstream) or due to reset. Every filter is responsible for making + * sure that any async events are cleaned up in the context of this routine. This includes timers, + * network calls, etc. The reason there is an onDestroy() method vs. doing this type of cleanup + * in the destructor is due to the deferred deletion model that Envoy uses to avoid stack unwind + * complications. Filters must not invoke either encoder or decoder filter callbacks after having + * onDestroy() invoked. + */ + virtual void onDestroy() PURE; + + /** + * Called by the connection manager once to initialize the filter decoder callbacks that the + * filter should use. Callbacks will not be invoked by the filter after onDestroy() is called. + */ + virtual void setDecoderFilterCallbacks(DecoderFilterCallbacks& callbacks) PURE; + + /** + * Resets the upstream connection. + */ + virtual void resetUpstreamConnection() PURE; + + /** + * Indicates the start of a Thrift transport frame was detected. Unframed transports generate + * simulated start messages. + * @param size the size of the message, if available to the transport + */ + virtual FilterStatus transportBegin(absl::optional size) PURE; + + /** + * Indicates the end of a Thrift transport frame was detected. Unframed transport generate + * simulated complete messages. + */ + virtual FilterStatus transportEnd() PURE; + + /** + * Indicates that the start of a Thrift protocol message was detected. + * @param name the name of the message, if available + * @param msg_type the type of the message + * @param seq_id the message sequence id + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) PURE; + + /** + * Indicates that the end of a Thrift protocol message was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus messageEnd() PURE; + + /** + * Indicates that the start of a Thrift protocol struct was detected. + * @param name the name of the struct, if available + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus structBegin(absl::string_view name) PURE; + + /** + * Indicates that the end of a Thrift protocol struct was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus structEnd() PURE; + + /** + * Indicates that the start of Thrift protocol struct field was detected. + * @param name the name of the field, if available + * @param field_type the type of the field + * @param field_id the field id + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) PURE; + + /** + * Indicates that the end of a Thrift protocol struct field was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus fieldEnd() PURE; + + /** + * A struct field, map key, map value, list element or set element was detected. + * @param value type value of the field + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus boolValue(bool value) PURE; + virtual FilterStatus byteValue(uint8_t value) PURE; + virtual FilterStatus int16Value(int16_t value) PURE; + virtual FilterStatus int32Value(int32_t value) PURE; + virtual FilterStatus int64Value(int64_t value) PURE; + virtual FilterStatus doubleValue(double value) PURE; + virtual FilterStatus stringValue(absl::string_view value) PURE; + + /** + * Indicates the start of a Thrift protocol map was detected. + * @param key_type the map key type + * @param value_type the map value type + * @param size the number of key-value pairs + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus mapBegin(FieldType key_type, FieldType value_type, uint32_t size) PURE; + + /** + * Indicates that the end of a Thrift protocol map was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus mapEnd() PURE; + + /** + * Indicates the start of a Thrift protocol list was detected. + * @param elem_type the list value type + * @param size the number of values in the list + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus listBegin(FieldType elem_type, uint32_t size) PURE; + + /** + * Indicates that the end of a Thrift protocol list was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus listEnd() PURE; + + /** + * Indicates the start of a Thrift protocol set was detected. + * @param elem_type the set value type + * @param size the number of values in the set + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus setBegin(FieldType elem_type, uint32_t size) PURE; + + /** + * Indicates that the end of a Thrift protocol set was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus setEnd() PURE; +}; + +typedef std::shared_ptr DecoderFilterSharedPtr; + +/** + * These callbacks are provided by the connection manager to the factory so that the factory can + * build the filter chain in an application specific way. + */ +class FilterChainFactoryCallbacks { +public: + virtual ~FilterChainFactoryCallbacks() {} + + /** + * Add a decoder filter that is used when reading connection data. + * @param filter supplies the filter to add. + */ + virtual void addDecoderFilter(DecoderFilterSharedPtr filter) PURE; +}; + +/** + * This function is used to wrap the creation of a Thrift filter chain for new connections as they + * come in. Filter factories create the function at configuration initialization time, and then + * they are used at runtime. + * @param callbacks supplies the callbacks for the stream to install filters to. Typically the + * function will install a single filter, but it's technically possibly to install more than one + * if desired. + */ +typedef std::function FilterFactoryCb; + +/** + * A FilterChainFactory is used by a connection manager to create a Thrift level filter chain when + * a new connection is created. Typically it would be implemented by a configuration engine that + * would install a set of filters that are able to process an application scenario on top of a + * stream of Thrift requests. + */ +class FilterChainFactory { +public: + virtual ~FilterChainFactory() {} + + /** + * Called when a new Thrift stream is created on the connection. + * @param callbacks supplies the "sink" that is used for actually creating the filter chain. @see + * FilterChainFactoryCallbacks. + */ + virtual void createFilterChain(FilterChainFactoryCallbacks& callbacks) PURE; +}; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter_config.h b/source/extensions/filters/network/thrift_proxy/filters/filter_config.h new file mode 100644 index 0000000000000..86f4b7730517b --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/filter_config.h @@ -0,0 +1,55 @@ +#pragma once + +#include "envoy/server/filter_config.h" + +#include "common/common/macros.h" +#include "common/protobuf/protobuf.h" + +#include "extensions/filters/network/thrift_proxy/filters/filter.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +/** + * Implemented by each Thrift filter and registered via Registry::registerFactory or the + * convenience class RegisterFactory. + */ +class NamedThriftFilterConfigFactory { +public: + virtual ~NamedThriftFilterConfigFactory() {} + + /** + * Create a particular thrift filter factory implementation. If the implementation is unable to + * produce a factory with the provided parameters, it should throw an EnvoyException in the case + * of general error. The returned callback should always be initialized. + * @param config supplies the configuration for the filter + * @param stat_prefix prefix for stat logging + * @param context supplies the filter's context. + * @return FilterFactoryCb the factory creation function. + */ + virtual FilterFactoryCb + createFilterFactoryFromProto(const Protobuf::Message& config, const std::string& stat_prefix, + Server::Configuration::FactoryContext& context) PURE; + + /** + * @return ProtobufTypes::MessagePtr create empty config proto message for v2. The filter + * config, which arrives in an opaque google.protobuf.Struct message, will be converted to + * JSON and then parsed into this empty proto. + */ + virtual ProtobufTypes::MessagePtr createEmptyConfigProto() PURE; + + /** + * @return std::string the identifying name for a particular implementation of a thrift filter + * produced by the factory. + */ + virtual std::string name() PURE; +}; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/well_known_names.h b/source/extensions/filters/network/thrift_proxy/filters/well_known_names.h new file mode 100644 index 0000000000000..41abac8d54753 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/well_known_names.h @@ -0,0 +1,25 @@ +#pragma once + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +/** + * Well-known http filter names. + * NOTE: New filters should use the well known name: envoy.filters.thrift.name. + */ +class ThriftFilterNameValues { +public: + // Router filter + const std::string ROUTER = "envoy.filters.thrift.router"; +}; + +typedef ConstSingleton ThriftFilterNames; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc index b66863e9bfbfb..a45861a349e4b 100644 --- a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc @@ -10,28 +10,26 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -bool FramedTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { +bool FramedTransportImpl::decodeFrameStart(Buffer::Instance& buffer, + absl::optional& size) { if (buffer.length() < 4) { return false; } - int32_t size = BufferHelper::peekI32(buffer); + int32_t thrift_size = BufferHelper::peekI32(buffer); - if (size <= 0 || size > MaxFrameSize) { - throw EnvoyException(fmt::format("invalid thrift framed transport frame size {}", size)); + if (thrift_size <= 0 || thrift_size > MaxFrameSize) { + throw EnvoyException(fmt::format("invalid thrift framed transport frame size {}", thrift_size)); } - onFrameStart(absl::optional(static_cast(size))); - buffer.drain(4); - return true; -} -bool FramedTransportImpl::decodeFrameEnd(Buffer::Instance&) { - onFrameComplete(); + size = static_cast(thrift_size); return true; } +bool FramedTransportImpl::decodeFrameEnd(Buffer::Instance&) { return true; } + void FramedTransportImpl::encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) { uint64_t size = message.length(); if (size == 0 || size > MaxFrameSize) { @@ -44,6 +42,17 @@ void FramedTransportImpl::encodeFrame(Buffer::Instance& buffer, Buffer::Instance buffer.move(message); } +class FramedTransportConfigFactory : public TransportFactoryBase { +public: + FramedTransportConfigFactory() : TransportFactoryBase(TransportNames::get().FRAMED) {} +}; + +/** + * Static registration for the framed transport. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h index b51b5b3819cd2..4c7569487ea38 100644 --- a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h +++ b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h @@ -17,13 +17,14 @@ namespace ThriftProxy { * FramedTransportImpl implements the Thrift Framed transport. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md */ -class FramedTransportImpl : public TransportImplBase { +class FramedTransportImpl : public Transport { public: - FramedTransportImpl(TransportCallbacks& callbacks) : TransportImplBase(callbacks) {} + FramedTransportImpl() {} // Transport const std::string& name() const override { return TransportNames::get().FRAMED; } - bool decodeFrameStart(Buffer::Instance& buffer) override; + TransportType type() const override { return TransportType::Framed; } + bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) override; bool decodeFrameEnd(Buffer::Instance& buffer) override; void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override; diff --git a/source/extensions/filters/network/thrift_proxy/protocol.h b/source/extensions/filters/network/thrift_proxy/protocol.h index 472f2d1852d51..37218b5586637 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol.h +++ b/source/extensions/filters/network/thrift_proxy/protocol.h @@ -5,7 +5,10 @@ #include "envoy/buffer/buffer.h" #include "envoy/common/pure.h" +#include "envoy/registry/registry.h" +#include "common/common/assert.h" +#include "common/config/utility.h" #include "common/singleton/const_singleton.h" #include "absl/strings/string_view.h" @@ -15,6 +18,16 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +enum class ProtocolType { + Binary, + LaxBinary, + Compact, + Auto, + + // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST PROTOCOL TYPE + LastProtocolType = Auto, +}; + /** * Names of available Protocol implementations. */ @@ -29,11 +42,23 @@ class ProtocolNameValues { // Compact protocol const std::string COMPACT = "compact"; - // JSON protocol - const std::string JSON = "json"; - // Auto-detection protocol const std::string AUTO = "auto"; + + const std::string& fromType(ProtocolType type) const { + switch (type) { + case ProtocolType::Binary: + return BINARY; + case ProtocolType::LaxBinary: + return LAX_BINARY; + case ProtocolType::Compact: + return COMPACT; + case ProtocolType::Auto: + return AUTO; + default: + NOT_REACHED; + } + } }; typedef ConstSingleton ProtocolNames; @@ -75,48 +100,6 @@ enum class FieldType { LastFieldType = List, }; -/** - * ProtocolCallbacks are Thrift protocol-level callbacks. - */ -class ProtocolCallbacks { -public: - virtual ~ProtocolCallbacks() {} - - /** - * Indicates that the start of a Thrift protocol message was detected. - * @param name the name of the message, if available - * @param msg_type the type of the message - * @param seq_id the message sequence id - */ - virtual void messageStart(const absl::string_view name, MessageType msg_type, - int32_t seq_id) PURE; - - /** - * Indicates that the start of a Thrift protocol struct was detected. - * @param name the name of the struct, if available - */ - virtual void structBegin(const absl::string_view name) PURE; - - /** - * Indicates that the start of Thrift protocol struct field was detected. - * @param name the name of the field, if available - * @param field_type the type of the field - * @param field_id the field id - */ - virtual void structField(const absl::string_view name, FieldType field_type, - int16_t field_id) PURE; - - /** - * Indicates that the end of a Thrift protocol struct was detected. - */ - virtual void structEnd() PURE; - - /** - * Indicates that the end of a Thrift protocol message was detected. - */ - virtual void messageComplete() PURE; -}; - /** * Protocol represents the operations necessary to implement the a generic Thrift protocol. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-protocol-spec.md @@ -125,8 +108,16 @@ class Protocol { public: virtual ~Protocol() {} + /** + * @return const std::string& the human-readable name of the protocol + */ virtual const std::string& name() const PURE; + /** + * @return ProtocolType the protocol type + */ + virtual ProtocolType type() const PURE; + /** * Reads the start of a Thrift protocol message from the buffer and updates the name, msg_type, * and seq_id parameters with values from the message header. If successful, the message header @@ -485,6 +476,52 @@ class Protocol { typedef std::unique_ptr ProtocolPtr; +/** + * Implemented by each Thrift protocol and registered via Registry::registerFactory or the + * convenience class RegisterFactory. + */ +class NamedProtocolConfigFactory { +public: + virtual ~NamedProtocolConfigFactory() {} + + /** + * Create a particular Thrift protocol + * @return ProtocolFactoryCb the protocol + */ + virtual ProtocolPtr createProtocol() PURE; + + /** + * @return std::string the identifying name for a particular implementation of thrift protocol + * produced by the factory. + */ + virtual std::string name() PURE; + + /** + * Convenience method to lookup a factory by type. + * @param ProtocolType the protocol type + * @return NamedProtocolConfigFactory& for the ProtocolType + */ + static NamedProtocolConfigFactory& getFactory(ProtocolType type) { + const std::string& name = ProtocolNames::get().fromType(type); + return Envoy::Config::Utility::getAndCheckFactory(name); + } +}; + +/** + * ProtocolFactoryBase provides a template for a trivial NamedProtocolConfigFactory. + */ +template class ProtocolFactoryBase : public NamedProtocolConfigFactory { + ProtocolPtr createProtocol() override { return std::move(std::make_unique()); } + + std::string name() override { return name_; } + +protected: + ProtocolFactoryBase(const std::string& name) : name_(name) {} + +private: + const std::string name_; +}; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/protocol_converter.h b/source/extensions/filters/network/thrift_proxy/protocol_converter.h new file mode 100644 index 0000000000000..87380eb5f3b89 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/protocol_converter.h @@ -0,0 +1,144 @@ +#pragma once + +#include "envoy/buffer/buffer.h" + +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/protocol.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * ProtocolConverter is an abstract class that implements protocol-related methods on + * ThriftFilters::DecoderFilter in terms of converting the decoded messages into a different + * protocol. + */ +class ProtocolConverter : public ThriftFilters::DecoderFilter { +public: + ProtocolConverter() {} + ~ProtocolConverter() {} + + void initProtocolConverter(ProtocolPtr&& proto, Buffer::Instance& buffer) { + proto_ = std::move(proto); + buffer_ = &buffer; + } + + // ThiftFilters::DecoderFilter + void onDestroy() override { NOT_IMPLEMENTED; } + void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks&) override { + NOT_IMPLEMENTED; + } + void resetUpstreamConnection() override { NOT_IMPLEMENTED; } + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override { + proto_->writeMessageBegin(*buffer_, std::string(name), msg_type, seq_id); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus messageEnd() override { + proto_->writeMessageEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus structBegin(absl::string_view name) override { + proto_->writeStructBegin(*buffer_, std::string(name)); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus structEnd() override { + proto_->writeFieldBegin(*buffer_, "", FieldType::Stop, 0); + proto_->writeStructEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) override { + proto_->writeFieldBegin(*buffer_, std::string(name), field_type, field_id); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus fieldEnd() override { + proto_->writeFieldEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus boolValue(bool value) override { + proto_->writeBool(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus byteValue(uint8_t value) override { + proto_->writeByte(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus int16Value(int16_t value) override { + proto_->writeInt16(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus int32Value(int32_t value) override { + proto_->writeInt32(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus int64Value(int64_t value) override { + proto_->writeInt64(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus doubleValue(double value) override { + proto_->writeDouble(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus stringValue(absl::string_view value) override { + proto_->writeString(*buffer_, std::string(value)); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus mapBegin(FieldType key_type, FieldType value_type, + uint32_t size) override { + proto_->writeMapBegin(*buffer_, key_type, value_type, size); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus mapEnd() override { + proto_->writeMapEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus listBegin(FieldType elem_type, uint32_t size) override { + proto_->writeListBegin(*buffer_, elem_type, size); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus listEnd() override { + proto_->writeListEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus setBegin(FieldType elem_type, uint32_t size) override { + proto_->writeSetBegin(*buffer_, elem_type, size); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus setEnd() override { + proto_->writeSetEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + +protected: + ProtocolType protocolType() const { return proto_->type(); } + +private: + ProtocolPtr proto_; + Buffer::Instance* buffer_{}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/protocol_impl.cc b/source/extensions/filters/network/thrift_proxy/protocol_impl.cc index f2cac638d2925..46636099f5430 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/protocol_impl.cc @@ -26,9 +26,9 @@ bool AutoProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& n uint16_t version = BufferHelper::peekU16(buffer); if (BinaryProtocolImpl::isMagic(version)) { - setProtocol(std::make_unique(callbacks_)); + setProtocol(std::make_unique()); } else if (CompactProtocolImpl::isMagic(version)) { - setProtocol(std::make_unique(callbacks_)); + setProtocol(std::make_unique()); } else { throw EnvoyException( fmt::format("unknown thrift auto protocol message start {:04x}", version)); @@ -45,6 +45,16 @@ bool AutoProtocolImpl::readMessageEnd(Buffer::Instance& buffer) { return protocol_->readMessageEnd(buffer); } +class AutoProtocolConfigFactory : public ProtocolFactoryBase { +public: + AutoProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().AUTO) {} +}; + +/** + * Static registration for the auto protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/protocol_impl.h b/source/extensions/filters/network/thrift_proxy/protocol_impl.h index def5061fa8e9e..0eb374e9f9869 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_impl.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_impl.h @@ -13,39 +13,24 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -/* - * ProtocolImplBase provides a base class for Protocol implementations. - */ -class ProtocolImplBase : public virtual Protocol { -public: - ProtocolImplBase(ProtocolCallbacks& callbacks) : callbacks_(callbacks) {} - -protected: - void onMessageStart(const absl::string_view name, MessageType msg_type, int32_t seq_id) const { - callbacks_.messageStart(name, msg_type, seq_id); - } - void onStructBegin(const absl::string_view name) const { callbacks_.structBegin(name); } - void onStructField(const absl::string_view name, FieldType field_type, int16_t field_id) const { - callbacks_.structField(name, field_type, field_id); - } - void onStructEnd() const { callbacks_.structEnd(); } - void onMessageComplete() const { callbacks_.messageComplete(); } - - ProtocolCallbacks& callbacks_; -}; - /** * AutoProtocolImpl attempts to distinguish between the Thrift binary (strict mode only) and * compact protocols and then delegates subsequent decoding operations to the appropriate Protocol * implementation. */ -class AutoProtocolImpl : public ProtocolImplBase { +class AutoProtocolImpl : public Protocol { public: - AutoProtocolImpl(ProtocolCallbacks& callbacks) - : ProtocolImplBase(callbacks), name_(ProtocolNames::get().AUTO) {} + AutoProtocolImpl() : name_(ProtocolNames::get().AUTO) {} // Protocol const std::string& name() const override { return name_; } + ProtocolType type() const override { + if (protocol_ != nullptr) { + return protocol_->type(); + } + return ProtocolType::Auto; + } + bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id) override; bool readMessageEnd(Buffer::Instance& buffer) override; diff --git a/source/extensions/filters/network/thrift_proxy/router/BUILD b/source/extensions/filters/network/thrift_proxy/router/BUILD new file mode 100644 index 0000000000000..2d32db0ee154c --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/BUILD @@ -0,0 +1,51 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + deps = [ + ":router_lib", + "//include/envoy/registry", + "//source/extensions/filters/network/thrift_proxy/filters:factory_base_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_config_interface", + "//source/extensions/filters/network/thrift_proxy/filters:well_known_names", + "@envoy_api//envoy/extensions/filters/network/thrift_proxy/v2alpha1/router:router_cc", + ], +) + +envoy_cc_library( + name = "router_interface", + hdrs = ["router.h"], + external_deps = ["abseil_optional"], + deps = [], +) + +envoy_cc_library( + name = "router_lib", + srcs = ["router_impl.cc"], + hdrs = ["router_impl.h"], + deps = [ + ":router_interface", + "//include/envoy/tcp:conn_pool_interface", + "//include/envoy/upstream:cluster_manager_interface", + "//include/envoy/upstream:load_balancer_interface", + "//include/envoy/upstream:thread_local_cluster_interface", + "//source/common/common:logger_lib", + "//source/extensions/filters/network/thrift_proxy:app_exception_lib", + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_converter_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_lib", + "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + "@envoy_api//envoy/extensions/filters/network/thrift_proxy/v2alpha1:thrift_proxy_cc", + ], +) diff --git a/source/extensions/filters/network/thrift_proxy/router/config.cc b/source/extensions/filters/network/thrift_proxy/router/config.cc new file mode 100644 index 0000000000000..abf7cafbc2655 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/config.cc @@ -0,0 +1,34 @@ +#include "extensions/filters/network/thrift_proxy/router/config.h" + +#include "envoy/registry/registry.h" + +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +ThriftFilters::FilterFactoryCb RouterFilterConfig::createFilterFactoryFromProtoTyped( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::router::Router& proto_config, + const std::string& stat_prefix, Server::Configuration::FactoryContext& context) { + UNREFERENCED_PARAMETER(proto_config); + UNREFERENCED_PARAMETER(stat_prefix); + + return [&context](ThriftFilters::FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addDecoderFilter(std::make_shared(context.clusterManager())); + }; +} + +/** + * Static registration for the router filter. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/config.h b/source/extensions/filters/network/thrift_proxy/router/config.h new file mode 100644 index 0000000000000..ae4629bfd0ada --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/config.h @@ -0,0 +1,32 @@ +#pragma once + +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.pb.h" +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.pb.validate.h" + +#include "extensions/filters/network/thrift_proxy/filters/factory_base.h" +#include "extensions/filters/network/thrift_proxy/filters/well_known_names.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +class RouterFilterConfig + : public ThriftFilters::FactoryBase< + envoy::extensions::filters::network::thrift_proxy::v2alpha1::router::Router> { +public: + RouterFilterConfig() : FactoryBase(ThriftFilters::ThriftFilterNames::get().ROUTER) {} + +private: + ThriftFilters::FilterFactoryCb createFilterFactoryFromProtoTyped( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::router::Router& + proto_config, + const std::string& stat_prefix, Server::Configuration::FactoryContext& context) override; +}; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/router.h b/source/extensions/filters/network/thrift_proxy/router/router.h new file mode 100644 index 0000000000000..32d717de52351 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/router.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +/** + * RouteEntry is an individual resolved route entry. + */ +class RouteEntry { +public: + virtual ~RouteEntry() {} + + /** + * @return const std::string& the upstream cluster that owns the route. + */ + virtual const std::string& clusterName() const PURE; +}; + +/** + * Route holds the RouteEntry for a request. + */ +class Route { +public: + virtual ~Route() {} + + /** + * @return the route entry or nullptr if there is no matching route for the request. + */ + virtual const RouteEntry* routeEntry() const PURE; +}; + +typedef std::shared_ptr RouteConstSharedPtr; + +/** + * The router configuration. + */ +class Config { +public: + virtual ~Config() {} + + /** + * Based on the incoming Thrift request transport and/or protocol data, determine the target + * route for the request. + * @param method supplies the thrift method name + * @return the route or nullptr if there is no matching route for the request. + */ + virtual RouteConstSharedPtr route(const std::string& method) const PURE; +}; + +typedef std::shared_ptr ConfigConstSharedPtr; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc new file mode 100644 index 0000000000000..eafd626af1b6f --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -0,0 +1,280 @@ +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" + +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" +#include "envoy/upstream/cluster_manager.h" +#include "envoy/upstream/thread_local_cluster.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +RouteEntryImplBase::RouteEntryImplBase( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::Route& route) + : cluster_name_(route.route().cluster()) {} + +const std::string& RouteEntryImplBase::clusterName() const { return cluster_name_; } + +const RouteEntry* RouteEntryImplBase::routeEntry() const { return this; } + +RouteConstSharedPtr RouteEntryImplBase::clusterEntry() const { return shared_from_this(); } + +MethodNameRouteEntryImpl::MethodNameRouteEntryImpl( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::Route& route) + : RouteEntryImplBase(route), method_name_(route.match().method()) {} + +RouteConstSharedPtr MethodNameRouteEntryImpl::matches(const std::string& method_name) const { + if (method_name_.empty() || method_name_ == method_name) { + return clusterEntry(); + } + + return nullptr; +} + +RouteMatcher::RouteMatcher( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration& config) { + for (const auto& route : config.routes()) { + routes_.emplace_back(new MethodNameRouteEntryImpl(route)); + } +} + +RouteConstSharedPtr RouteMatcher::route(const std::string& method_name) const { + for (const auto& route : routes_) { + RouteConstSharedPtr route_entry = route->matches(method_name); + if (nullptr != route_entry) { + return route_entry; + } + } + + return nullptr; +} + +void Router::onDestroy() { + if (upstream_request_ != nullptr) { + upstream_request_->resetStream(); + } + cleanup(); +} + +void Router::setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks& callbacks) { + callbacks_ = &callbacks; + + // TODO(zuercher): handle buffer limits +} + +void Router::resetUpstreamConnection() { + if (upstream_request_ != nullptr) { + upstream_request_->resetStream(); + } +} + +ThriftFilters::FilterStatus Router::transportBegin(absl::optional size) { + UNREFERENCED_PARAMETER(size); + return ThriftFilters::FilterStatus::Continue; +} + +ThriftFilters::FilterStatus Router::transportEnd() { + if (upstream_request_->msg_type_ == MessageType::Oneway) { + // No response expected + upstream_request_->onResponseComplete(); + cleanup(); + } + return ThriftFilters::FilterStatus::Continue; +} + +ThriftFilters::FilterStatus Router::messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) { + // TODO(zuercher): route stats (e.g., no_route, no_cluster, upstream_rq_maintenance_mode, no + // healtthy upstream) + + route_ = callbacks_->route(); + if (!route_) { + ENVOY_STREAM_LOG(debug, "no cluster match for method '{}'", *callbacks_, name); + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{ + new AppException(name, seq_id, AppExceptionType::UnknownMethod, + fmt::format("no route for method '{}'", name))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + route_entry_ = route_->routeEntry(); + + Upstream::ThreadLocalCluster* cluster = cluster_manager_.get(route_entry_->clusterName()); + if (!cluster) { + ENVOY_STREAM_LOG(debug, "unknown cluster '{}'", *callbacks_, route_entry_->clusterName()); + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{ + new AppException(name, seq_id, AppExceptionType::InternalError, + fmt::format("unknown cluster '{}'", route_entry_->clusterName()))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + cluster_ = cluster->info(); + ENVOY_STREAM_LOG(debug, "cluster '{}' match for method '{}'", *callbacks_, + route_entry_->clusterName(), name); + + if (cluster_->maintenanceMode()) { + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + name, seq_id, AppExceptionType::InternalError, + fmt::format("maintenance mode for cluster '{}'", route_entry_->clusterName()))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + Tcp::ConnectionPool::Instance* conn_pool = cluster_manager_.tcpConnPoolForCluster( + route_entry_->clusterName(), Upstream::ResourcePriority::Default, this); + if (!conn_pool) { + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + name, seq_id, AppExceptionType::InternalError, + fmt::format("no healthy upstream for '{}'", route_entry_->clusterName()))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + ENVOY_STREAM_LOG(debug, "router decoding request", *callbacks_); + + upstream_request_.reset(new UpstreamRequest(*this, *conn_pool, name, msg_type, seq_id)); + upstream_request_->start(); + return ThriftFilters::FilterStatus::StopIteration; +} + +ThriftFilters::FilterStatus Router::messageEnd() { + ProtocolConverter::messageEnd(); + + Buffer::OwnedImpl transport_buffer; + upstream_request_->transport_->encodeFrame(transport_buffer, upstream_request_buffer_); + upstream_request_->conn_data_->connection().write(transport_buffer, false); + upstream_request_->onRequestComplete(); + return ThriftFilters::FilterStatus::Continue; +} + +void Router::onUpstreamData(Buffer::Instance& data, bool end_stream) { + ASSERT(!upstream_request_->response_complete_); + + if (!upstream_request_->response_started_) { + callbacks_->startUpstreamResponse(upstream_request_->transport_->type(), protocolType()); + upstream_request_->response_started_ = true; + } + + if (callbacks_->upstreamData(data)) { + upstream_request_->onResponseComplete(); + cleanup(); + return; + } + + if (end_stream) { + // Response is incomplete, but no more data is coming. + upstream_request_->onResponseComplete(); + upstream_request_->onResetStream(Tcp::ConnectionPool::PoolFailureReason::ConnectionFailure); + cleanup(); + } +} + +const Network::Connection* Router::downstreamConnection() const { + if (callbacks_ != nullptr) { + return callbacks_->connection(); + } + + return nullptr; +} + +void Router::convertMessageBegin(const std::string& name, MessageType msg_type, int32_t seq_id) { + ProtocolConverter::messageBegin(absl::string_view(name), msg_type, seq_id); +} + +void Router::cleanup() { upstream_request_.reset(); } + +Router::UpstreamRequest::UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, + absl::string_view method_name, MessageType msg_type, + int32_t seq_id) + : parent_(parent), conn_pool_(pool), method_name_(std::string(method_name)), + msg_type_(msg_type), seq_id_(seq_id), request_complete_(false), response_started_(false), + response_complete_(false) {} + +Router::UpstreamRequest::~UpstreamRequest() {} + +void Router::UpstreamRequest::start() { + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(*this); + if (handle) { + conn_pool_handle_ = handle; + } +} + +void Router::UpstreamRequest::resetStream() { + if (conn_data_ != nullptr) { + conn_data_->connection().close(Network::ConnectionCloseType::NoFlush); + conn_data_ = nullptr; + } +} + +void Router::UpstreamRequest::onPoolFailure(Tcp::ConnectionPool::PoolFailureReason reason, + Upstream::HostDescriptionConstSharedPtr host) { + // Mimic an upstream reset. + onUpstreamHostSelected(host); + onResetStream(reason); +} + +void Router::UpstreamRequest::onPoolReady(Tcp::ConnectionPool::ConnectionData& conn_data, + Upstream::HostDescriptionConstSharedPtr host) { + onUpstreamHostSelected(host); + conn_data_ = &conn_data; + conn_data_->addUpstreamCallbacks(parent_); + + conn_pool_handle_ = nullptr; + + // TODO(zuercher): let cluster specify a specific transport and protocol + transport_ = + NamedTransportConfigFactory::getFactory(parent_.callbacks_->downstreamTransportType()) + .createTransport(); + + parent_.initProtocolConverter( + NamedProtocolConfigFactory::getFactory(parent_.callbacks_->downstreamProtocolType()) + .createProtocol(), + parent_.upstream_request_buffer_); + + // TODO(zuercher): need to use an upstream-connection-specific sequence id + parent_.convertMessageBegin(method_name_, msg_type_, seq_id_); + + parent_.callbacks_->continueDecoding(); +} + +void Router::UpstreamRequest::onRequestComplete() { request_complete_ = true; } + +void Router::UpstreamRequest::onResponseComplete() { + response_complete_ = true; + if (conn_data_ != nullptr) { + conn_data_->release(); + } + conn_data_ = nullptr; +} + +void Router::UpstreamRequest::onUpstreamHostSelected(Upstream::HostDescriptionConstSharedPtr host) { + upstream_host_ = host; +} + +void Router::UpstreamRequest::onResetStream(Tcp::ConnectionPool::PoolFailureReason reason) { + switch (reason) { + case Tcp::ConnectionPool::PoolFailureReason::Overflow: + parent_.callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + method_name_, seq_id_, AppExceptionType::InternalError, + fmt::format("too many connections to '{}'", upstream_host_->address()->asString()))}); + break; + case Tcp::ConnectionPool::PoolFailureReason::ConnectionFailure: + if (!response_started_) { + parent_.callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + method_name_, seq_id_, AppExceptionType::InternalError, + fmt::format("connection failure '{}'", upstream_host_->address()->asString()))}); + return; + } + + parent_.callbacks_->resetDownstreamConnection(); + break; + default: + NOT_REACHED; + } +} + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.h b/source/extensions/filters/network/thrift_proxy/router/router_impl.h new file mode 100644 index 0000000000000..340a2a66cc4c6 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.h @@ -0,0 +1,157 @@ +#pragma once + +#include +#include +#include + +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" +#include "envoy/tcp/conn_pool.h" +#include "envoy/upstream/load_balancer.h" + +#include "common/common/logger.h" + +#include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/router/router.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +class RouteEntryImplBase : public RouteEntry, + public Route, + public std::enable_shared_from_this { +public: + RouteEntryImplBase( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::Route& route); + + // Router::RouteEntry + const std::string& clusterName() const override; + + // Router::Route + const RouteEntry* routeEntry() const override; + + virtual RouteConstSharedPtr matches(const std::string& method_name) const PURE; + +protected: + RouteConstSharedPtr clusterEntry() const; + +private: + const std::string cluster_name_; +}; + +typedef std::shared_ptr RouteEntryImplBaseConstSharedPtr; + +class MethodNameRouteEntryImpl : public RouteEntryImplBase { +public: + MethodNameRouteEntryImpl( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::Route& route); + + const std::string& methodName() const { return method_name_; } + + // RoutEntryImplBase + RouteConstSharedPtr matches(const std::string& method_name) const override; + +private: + const std::string method_name_; +}; + +class RouteMatcher { +public: + RouteMatcher( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration&); + + RouteConstSharedPtr route(const std::string& method_name) const; + +private: + std::vector routes_; +}; + +class Router : public Tcp::ConnectionPool::UpstreamCallbacks, + public Upstream::LoadBalancerContext, + public ProtocolConverter, + Logger::Loggable { +public: + Router(Upstream::ClusterManager& cluster_manager) : cluster_manager_(cluster_manager) {} + + ~Router() {} + + // ProtocolConverter + void onDestroy() override; + void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks& callbacks) override; + void resetUpstreamConnection() override; + ThriftFilters::FilterStatus transportBegin(absl::optional size) override; + ThriftFilters::FilterStatus transportEnd() override; + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override; + ThriftFilters::FilterStatus messageEnd() override; + + // Upstream::LoadBalancerContext + absl::optional computeHashKey() override { return {}; } + const Envoy::Router::MetadataMatchCriteria* metadataMatchCriteria() override { return nullptr; } + const Network::Connection* downstreamConnection() const override; + const Http::HeaderMap* downstreamHeaders() const override { return nullptr; } + + // Tcp::ConnectionPool::UpstreamCallbacks + void onUpstreamData(Buffer::Instance& data, bool end_stream) override; + +private: + struct UpstreamRequest : public Tcp::ConnectionPool::Callbacks { + UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, + absl::string_view method_name, MessageType msg_type, int32_t seq_id); + ~UpstreamRequest(); + + void start(); + void resetStream(); + + // Tcp::ConnectionPool::Callbacks + void onPoolFailure(Tcp::ConnectionPool::PoolFailureReason reason, + Upstream::HostDescriptionConstSharedPtr host) override; + void onPoolReady(Tcp::ConnectionPool::ConnectionData& conn, + Upstream::HostDescriptionConstSharedPtr host) override; + + void onRequestComplete(); + void onResponseComplete(); + void onUpstreamHostSelected(Upstream::HostDescriptionConstSharedPtr host); + void onResetStream(Tcp::ConnectionPool::PoolFailureReason reason); + + Router& parent_; + Tcp::ConnectionPool::Instance& conn_pool_; + const std::string method_name_; + const MessageType msg_type_; + const int32_t seq_id_; + + Tcp::ConnectionPool::Cancellable* conn_pool_handle_{}; + Tcp::ConnectionPool::ConnectionData* conn_data_{}; + Upstream::HostDescriptionConstSharedPtr upstream_host_; + TransportPtr transport_; + ProtocolType proto_type_{ProtocolType::Auto}; + + bool request_complete_ : 1; + bool response_started_ : 1; + bool response_complete_ : 1; + }; + + void convertMessageBegin(const std::string& name, MessageType msg_type, int32_t seq_id); + void cleanup(); + + Upstream::ClusterManager& cluster_manager_; + + ThriftFilters::DecoderFilterCallbacks* callbacks_{}; + RouteConstSharedPtr route_{}; + const RouteEntry* route_entry_{}; + Upstream::ClusterInfoConstSharedPtr cluster_; + + std::unique_ptr upstream_request_; + Buffer::OwnedImpl upstream_request_buffer_; +}; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/stats.h b/source/extensions/filters/network/thrift_proxy/stats.h new file mode 100644 index 0000000000000..a630fc735d034 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/stats.h @@ -0,0 +1,52 @@ +#pragma once + +#include + +#include "envoy/stats/stats.h" +#include "envoy/stats/stats_macros.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * All thrift filter stats. @see stats_macros.h + */ +// clang-format off +#define ALL_THRIFT_FILTER_STATS(COUNTER, GAUGE, HISTOGRAM) \ + COUNTER(request) \ + COUNTER(request_call) \ + COUNTER(request_oneway) \ + COUNTER(request_invalid_type) \ + GAUGE(request_active) \ + COUNTER(request_decoding_error) \ + HISTOGRAM(request_time_ms) \ + COUNTER(response) \ + COUNTER(response_reply) \ + COUNTER(response_success) \ + COUNTER(response_error) \ + COUNTER(response_exception) \ + COUNTER(response_invalid_type) \ + COUNTER(response_decoding_error) \ + COUNTER(cx_destroy_local_with_active_rq) \ + COUNTER(cx_destroy_remote_with_active_rq) +// clang-format on + +/** + * Struct definition for all mongo proxy stats. @see stats_macros.h + */ +struct ThriftFilterStats { + ALL_THRIFT_FILTER_STATS(GENERATE_COUNTER_STRUCT, GENERATE_GAUGE_STRUCT, GENERATE_HISTOGRAM_STRUCT) + + static ThriftFilterStats generateStats(const std::string& prefix, Stats::Scope& scope) { + return ThriftFilterStats{ALL_THRIFT_FILTER_STATS(POOL_COUNTER_PREFIX(scope, prefix), + POOL_GAUGE_PREFIX(scope, prefix), + POOL_HISTOGRAM_PREFIX(scope, prefix))}; + } +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/transport.h b/source/extensions/filters/network/thrift_proxy/transport.h index 516c055ee72b9..776f3a91d8d8b 100644 --- a/source/extensions/filters/network/thrift_proxy/transport.h +++ b/source/extensions/filters/network/thrift_proxy/transport.h @@ -4,7 +4,10 @@ #include #include "envoy/buffer/buffer.h" +#include "envoy/registry/registry.h" +#include "common/common/assert.h" +#include "common/config/utility.h" #include "common/singleton/const_singleton.h" #include "absl/types/optional.h" @@ -14,6 +17,16 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +enum class TransportType { + Framed, + Unframed, + Auto, + + // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST TRANSPORT TYPE + LastTransportType = Auto, + +}; + /** * Names of available Transport implementations. */ @@ -27,29 +40,23 @@ class TransportNameValues { // Auto-detection transport const std::string AUTO = "auto"; + + const std::string& fromType(TransportType type) const { + switch (type) { + case TransportType::Framed: + return FRAMED; + case TransportType::Unframed: + return UNFRAMED; + case TransportType::Auto: + return AUTO; + default: + NOT_REACHED; + } + } }; typedef ConstSingleton TransportNames; -/** - * TransportCallbacks are Thrift transport-level callbacks. - */ -class TransportCallbacks { -public: - virtual ~TransportCallbacks() {} - - /** - * Indicates the start of a Thrift transport frame was detected. - * @param size the size of the message, if available to the transport - */ - virtual void transportFrameStart(absl::optional size) PURE; - - /** - * Indicates the end of a Thrift transport frame was detected. - */ - virtual void transportFrameComplete() PURE; -}; - /** * Transport represents a Thrift transport. The Thrift transport is nominally a generic, * bi-directional byte stream. In Envoy we assume it always represents a network byte stream and @@ -66,20 +73,27 @@ class Transport { */ virtual const std::string& name() const PURE; + /** + * @return TransportType the transport type + */ + virtual TransportType type() const PURE; + /* - * decodeFrameStart decodes the start of a transport message, potentially invoking callbacks. - * If successful, the start of the frame is removed from the buffer. + * Decodes the start of a transport message. If successful, the start of the frame is removed + * from the buffer. * * @param buffer the currently buffered thrift data. + * @param size updated with the frame size on success. If frame size is not encoded, the size + * is cleared on success. * @return bool true if a complete frame header was successfully consumed, false if more data * is required. * @throws EnvoyException if the data is not valid for this transport. */ - virtual bool decodeFrameStart(Buffer::Instance& buffer) PURE; + virtual bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) PURE; /* - * decodeFrameEnd decodes the end of a transport message, potentially invoking callbacks. - * If successful, the end of the frame is removed from the buffer. + * Decodes the end of a transport message. If successful, the end of the frame is removed from + * the buffer. * * @param buffer the currently buffered thrift data. * @return bool true if a complete frame trailer was successfully consumed, false if more data @@ -89,8 +103,8 @@ class Transport { virtual bool decodeFrameEnd(Buffer::Instance& buffer) PURE; /** - * encodeFrame wraps the given message buffer with the transport's header and trailer (if any). - * After encoding, message will be empty. + * Wraps the given message buffer with the transport's header and trailer (if any). After + * encoding, message will be empty. * @param buffer is the output buffer * @param message a protocol-encoded message * @throws EnvoyException if the message is too large for the transport @@ -100,6 +114,52 @@ class Transport { typedef std::unique_ptr TransportPtr; +/** + * Implemented by each Thrift transport and registered via Registry::registerFactory or the + * convenience class RegisterFactory. + */ +class NamedTransportConfigFactory { +public: + virtual ~NamedTransportConfigFactory() {} + + /** + * Create a particular Thrift transport. + * @return TransportPtr the transport + */ + virtual TransportPtr createTransport() PURE; + + /** + * @return std::string the identifying name for a particular implementation of thrift transport + * produced by the factory. + */ + virtual std::string name() PURE; + + /** + * Convenience method to lookup a factory by type. + * @param TransportType the transport type + * @return NamedTransportConfigFactory& for the TransportType + */ + static NamedTransportConfigFactory& getFactory(TransportType type) { + const std::string& name = TransportNames::get().fromType(type); + return Envoy::Config::Utility::getAndCheckFactory(name); + } +}; + +/** + * TransportFactoryBase provides a template for a trivial NamedTransportConfigFactory. + */ +template class TransportFactoryBase : public NamedTransportConfigFactory { + TransportPtr createTransport() override { return std::move(std::make_unique()); } + + std::string name() override { return name_; } + +protected: + TransportFactoryBase(const std::string& name) : name_(name) {} + +private: + const std::string name_; +}; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/transport_impl.cc b/source/extensions/filters/network/thrift_proxy/transport_impl.cc index 4caa88efcb313..ff177884d94c5 100644 --- a/source/extensions/filters/network/thrift_proxy/transport_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/transport_impl.cc @@ -15,7 +15,7 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { +bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) { if (transport_ == nullptr) { // Not enough data to select a transport. if (buffer.length() < 8) { @@ -30,13 +30,13 @@ bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { // is configurable, but defaults to 256 MB (0x1000000). THeaderTransport will take up to ~1GB // (0x3FFFFFFF) when it falls back to framed mode. if (BinaryProtocolImpl::isMagic(proto_start) || CompactProtocolImpl::isMagic(proto_start)) { - setTransport(std::make_unique(callbacks_)); + setTransport(std::make_unique()); } } else { // Check for sane unframed protocol. proto_start = static_cast((size >> 16) & 0xFFFF); if (BinaryProtocolImpl::isMagic(proto_start) || CompactProtocolImpl::isMagic(proto_start)) { - setTransport(std::make_unique(callbacks_)); + setTransport(std::make_unique()); } } @@ -51,7 +51,7 @@ bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { } } - return transport_->decodeFrameStart(buffer); + return transport_->decodeFrameStart(buffer, size); } bool AutoTransportImpl::decodeFrameEnd(Buffer::Instance& buffer) { @@ -64,6 +64,16 @@ void AutoTransportImpl::encodeFrame(Buffer::Instance& buffer, Buffer::Instance& transport_->encodeFrame(buffer, message); } +class AutoTransportConfigFactory : public TransportFactoryBase { +public: + AutoTransportConfigFactory() : TransportFactoryBase(TransportNames::get().AUTO) {} +}; + +/** + * Static registration for the auto transport. @see RegisterFactory. + */ +static Registry::RegisterFactory register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/transport_impl.h b/source/extensions/filters/network/thrift_proxy/transport_impl.h index 137a58a622c9f..08281c53c6d16 100644 --- a/source/extensions/filters/network/thrift_proxy/transport_impl.h +++ b/source/extensions/filters/network/thrift_proxy/transport_impl.h @@ -15,33 +15,25 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -/* - * TransportImplBase provides a base class for Transport implementations. - */ -class TransportImplBase : public virtual Transport { -public: - TransportImplBase(TransportCallbacks& callbacks) : callbacks_(callbacks) {} - -protected: - void onFrameStart(absl::optional size) const { callbacks_.transportFrameStart(size); } - void onFrameComplete() const { callbacks_.transportFrameComplete(); } - - TransportCallbacks& callbacks_; -}; - /** * AutoTransportImpl implements Transport and attempts to distinguish between the Thrift framed and * unframed transports. Once the transport is detected, subsequent operations are delegated to the * appropriate implementation. */ -class AutoTransportImpl : public TransportImplBase { +class AutoTransportImpl : public Transport { public: - AutoTransportImpl(TransportCallbacks& callbacks) - : TransportImplBase(callbacks), name_(TransportNames::get().AUTO){}; + AutoTransportImpl() : name_(TransportNames::get().AUTO){}; // Transport const std::string& name() const override { return name_; } - bool decodeFrameStart(Buffer::Instance& buffer) override; + TransportType type() const override { + if (transport_ != nullptr) { + return transport_->type(); + } + + return TransportType::Auto; + } + bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) override; bool decodeFrameEnd(Buffer::Instance& buffer) override; void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override; diff --git a/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.cc b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.cc new file mode 100644 index 0000000000000..d3a2744540c96 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.cc @@ -0,0 +1,22 @@ +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class UnframedTransportConfigFactory : public TransportFactoryBase { +public: + UnframedTransportConfigFactory() : TransportFactoryBase(TransportNames::get().UNFRAMED) {} +}; + +/** + * Static registration for the unframed transport. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h index af7ea884ddba0..29992dfc0812f 100644 --- a/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h +++ b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h @@ -17,20 +17,18 @@ namespace ThriftProxy { * UnframedTransportImpl implements the Thrift Unframed transport. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md */ -class UnframedTransportImpl : public TransportImplBase { +class UnframedTransportImpl : public Transport { public: - UnframedTransportImpl(TransportCallbacks& callbacks) : TransportImplBase(callbacks) {} + UnframedTransportImpl() {} // Transport const std::string& name() const override { return TransportNames::get().UNFRAMED; } - bool decodeFrameStart(Buffer::Instance&) override { - onFrameStart(absl::optional()); - return true; - } - bool decodeFrameEnd(Buffer::Instance&) override { - onFrameComplete(); + TransportType type() const override { return TransportType::Unframed; } + bool decodeFrameStart(Buffer::Instance&, absl::optional& size) override { + size.reset(); return true; } + bool decodeFrameEnd(Buffer::Instance&) override { return true; } void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override { buffer.move(message); } diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index faa42cb93aff6..1eea2452232bb 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -19,7 +19,12 @@ envoy_extension_cc_mock( hdrs = ["mocks.h"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_lib", "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + "//test/mocks/network:network_mocks", "//test/test_common:printers_lib", ], ) @@ -79,33 +84,40 @@ envoy_extension_cc_test( extension_name = "envoy.filters.network.thrift_proxy", deps = [ "//source/extensions/filters/network/thrift_proxy:config", + "//source/extensions/filters/network/thrift_proxy/router:config", "//test/mocks/server:server_mocks", ], ) envoy_extension_cc_test( - name = "decoder_test", - srcs = ["decoder_test.cc"], + name = "conn_manager_test", + srcs = ["conn_manager_test.cc"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ ":mocks", ":utility_lib", - "//source/extensions/filters/network/thrift_proxy:decoder_lib", + "//source/extensions/filters/network/thrift_proxy:config", + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + "//source/extensions/filters/network/thrift_proxy/router:config", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + "//test/mocks/network:network_mocks", + "//test/mocks/server:server_mocks", + "//test/mocks/upstream:upstream_mocks", "//test/test_common:printers_lib", - "//test/test_common:utility_lib", ], ) envoy_extension_cc_test( - name = "filter_test", - srcs = ["filter_test.cc"], + name = "decoder_test", + srcs = ["decoder_test.cc"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ + ":mocks", ":utility_lib", - "//source/common/stats:stats_lib", - "//source/extensions/filters/network/thrift_proxy:filter_lib", - "//test/mocks/network:network_mocks", + "//source/extensions/filters/network/thrift_proxy:decoder_lib", "//test/test_common:printers_lib", + "//test/test_common:utility_lib", ], ) @@ -117,7 +129,6 @@ envoy_extension_cc_test( ":mocks", ":utility_lib", "//source/extensions/filters/network/thrift_proxy:transport_lib", - "//test/mocks/buffer:buffer_mocks", "//test/test_common:printers_lib", "//test/test_common:utility_lib", ], @@ -136,6 +147,24 @@ envoy_extension_cc_test( ], ) +envoy_extension_cc_test( + name = "router_test", + srcs = ["router_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + ":mocks", + ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:app_exception_lib", + "//source/extensions/filters/network/thrift_proxy/router:config", + "//source/extensions/filters/network/thrift_proxy/router:router_lib", + "//test/mocks/network:network_mocks", + "//test/mocks/server:server_mocks", + "//test/mocks/upstream:upstream_mocks", + "//test/test_common:printers_lib", + "//test/test_common:registry_lib", + ], +) + envoy_extension_cc_test( name = "transport_impl_test", srcs = ["transport_impl_test.cc"], @@ -163,8 +192,8 @@ envoy_extension_cc_test( ) envoy_extension_cc_test( - name = "filter_integration_test", - srcs = ["filter_integration_test.cc"], + name = "integration_test", + srcs = ["integration_test.cc"], data = [ "//test/extensions/filters/network/thrift_proxy/driver:generate_fixture", ], @@ -172,7 +201,8 @@ envoy_extension_cc_test( deps = [ "//source/extensions/filters/network/tcp_proxy:config", "//source/extensions/filters/network/thrift_proxy:config", - "//source/extensions/filters/network/thrift_proxy:filter_lib", + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy/router:config", "//test/integration:integration_lib", "//test/test_common:environment_lib", "//test/test_common:network_utility_lib", diff --git a/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc index 6ed1723fee094..3db9d588d5338 100644 --- a/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc @@ -4,29 +4,24 @@ #include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" -#include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" #include "gtest/gtest.h" -using testing::StrictMock; - namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { TEST(BinaryProtocolTest, Name) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_EQ(proto.name(), "binary"); } TEST(BinaryProtocolTest, ReadMessageBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -97,7 +92,6 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { addInt32(buffer, 0); addInt32(buffer, 1234); - EXPECT_CALL(cb, messageStart(absl::string_view(""), MessageType::Call, 1234)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, ""); EXPECT_EQ(msg_type, MessageType::Call); @@ -139,7 +133,6 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { addString(buffer, "the_name"); addInt32(buffer, 5678); - EXPECT_CALL(cb, messageStart(absl::string_view("the_name"), MessageType::Call, 5678)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, "the_name"); EXPECT_EQ(msg_type, MessageType::Call); @@ -150,34 +143,29 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { TEST(BinaryProtocolTest, ReadMessageEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; - EXPECT_CALL(cb, messageComplete()); EXPECT_TRUE(proto.readMessageEnd(buffer)); } TEST(BinaryProtocolTest, ReadStructBegin) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; std::string name = "-"; - EXPECT_CALL(cb, structBegin(absl::string_view(""))); + EXPECT_TRUE(proto.readStructBegin(buffer, name)); EXPECT_EQ(name, ""); } TEST(BinaryProtocolTest, ReadStructEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); - EXPECT_CALL(cb, structEnd()); + BinaryProtocolImpl proto; + EXPECT_TRUE(proto.readStructEnd(buffer)); } TEST(BinaryProtocolTest, ReadFieldBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -201,7 +189,6 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { addInt8(buffer, FieldType::Stop); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Stop, 0)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Stop); @@ -234,7 +221,6 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { addInt8(buffer, FieldType::I32); addInt16(buffer, 99); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::I32, 99)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::I32); @@ -263,14 +249,12 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { TEST(BinaryProtocolTest, ReadFieldEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readFieldEnd(buffer)); } TEST(BinaryProtocolTest, ReadMapBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -328,14 +312,12 @@ TEST(BinaryProtocolTest, ReadMapBegin) { TEST(BinaryProtocolTest, ReadMapEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readMapEnd(buffer)); } TEST(BinaryProtocolTest, ReadListBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -385,14 +367,12 @@ TEST(BinaryProtocolTest, ReadListBegin) { TEST(BinaryProtocolTest, ReadListEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readListEnd(buffer)); } TEST(BinaryProtocolTest, ReadSetBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Test only the happy path, since this method is just delegated to readListBegin() Buffer::OwnedImpl buffer; @@ -410,14 +390,12 @@ TEST(BinaryProtocolTest, ReadSetBegin) { TEST(BinaryProtocolTest, ReadSetEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readSetEnd(buffer)); } TEST(BinaryProtocolTest, ReadIntegerTypes) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Bool { @@ -535,8 +513,7 @@ TEST(BinaryProtocolTest, ReadIntegerTypes) { } TEST(BinaryProtocolTest, ReadDouble) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -564,8 +541,7 @@ TEST(BinaryProtocolTest, ReadDouble) { } TEST(BinaryProtocolTest, ReadString) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data to read length { @@ -632,8 +608,7 @@ TEST(BinaryProtocolTest, ReadString) { TEST(BinaryProtocolTest, ReadBinary) { // Test only the happy path, since this method is just delegated to readString() - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; std::string value = "-"; @@ -646,8 +621,7 @@ TEST(BinaryProtocolTest, ReadBinary) { } TEST(BinaryProtocolTest, WriteMessageBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Named call { @@ -665,32 +639,28 @@ TEST(BinaryProtocolTest, WriteMessageBegin) { } TEST(BinaryProtocolTest, WriteMessageEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMessageEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteStructBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeStructBegin(buffer, "unused"); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteStructEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeStructEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteFieldBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Stop field { @@ -708,16 +678,14 @@ TEST(BinaryProtocolTest, WriteFieldBegin) { } TEST(BinaryProtocolTest, WriteFieldEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeFieldEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteMapBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Non-empty map { @@ -743,16 +711,14 @@ TEST(BinaryProtocolTest, WriteMapBegin) { } TEST(BinaryProtocolTest, WriteMapEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMapEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteListBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Non-empty list { @@ -777,16 +743,14 @@ TEST(BinaryProtocolTest, WriteListBegin) { } TEST(BinaryProtocolTest, WriteListEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeListEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteSetBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Only test the happy path, as this shares an implementation with writeListBegin // Non-empty list @@ -796,16 +760,14 @@ TEST(BinaryProtocolTest, WriteSetBegin) { } TEST(BinaryProtocolTest, WriteSetEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeSetEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteBool) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // True { @@ -823,8 +785,7 @@ TEST(BinaryProtocolTest, WriteBool) { } TEST(BinaryProtocolTest, WriteByte) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -840,8 +801,7 @@ TEST(BinaryProtocolTest, WriteByte) { } TEST(BinaryProtocolTest, WriteInt16) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -857,8 +817,7 @@ TEST(BinaryProtocolTest, WriteInt16) { } TEST(BinaryProtocolTest, WriteInt32) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -874,8 +833,7 @@ TEST(BinaryProtocolTest, WriteInt32) { } TEST(BinaryProtocolTest, WriteInt64) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -891,16 +849,14 @@ TEST(BinaryProtocolTest, WriteInt64) { } TEST(BinaryProtocolTest, WriteDouble) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeDouble(buffer, 3.0); EXPECT_EQ(std::string("\x40\x8\0\0\0\0\0\0", 8), buffer.toString()); } TEST(BinaryProtocolTest, WriteString) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -919,8 +875,7 @@ TEST(BinaryProtocolTest, WriteString) { } TEST(BinaryProtocolTest, WriteBinary) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Happy path only, since this is just a synonym for writeString Buffer::OwnedImpl buffer; @@ -932,14 +887,12 @@ TEST(BinaryProtocolTest, WriteBinary) { } TEST(LaxBinaryProtocolTest, Name) { - StrictMock cb; - LaxBinaryProtocolImpl proto(cb); + LaxBinaryProtocolImpl proto; EXPECT_EQ(proto.name(), "binary/non-strict"); } TEST(LaxBinaryProtocolTest, ReadMessageBegin) { - StrictMock cb; - LaxBinaryProtocolImpl proto(cb); + LaxBinaryProtocolImpl proto; // Insufficient data { @@ -989,7 +942,6 @@ TEST(LaxBinaryProtocolTest, ReadMessageBegin) { addInt8(buffer, MessageType::Call); addInt32(buffer, 1234); - EXPECT_CALL(cb, messageStart(absl::string_view(""), MessageType::Call, 1234)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, ""); EXPECT_EQ(msg_type, MessageType::Call); @@ -1027,7 +979,6 @@ TEST(LaxBinaryProtocolTest, ReadMessageBegin) { addInt8(buffer, MessageType::Call); addInt32(buffer, 5678); - EXPECT_CALL(cb, messageStart(absl::string_view("the_name"), MessageType::Call, 5678)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, "the_name"); EXPECT_EQ(msg_type, MessageType::Call); @@ -1037,8 +988,7 @@ TEST(LaxBinaryProtocolTest, ReadMessageBegin) { } TEST(LaxBinaryProtocolTest, WriteMessageBegin) { - StrictMock cb; - LaxBinaryProtocolImpl proto(cb); + LaxBinaryProtocolImpl proto; // Named call { diff --git a/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc b/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc index 9536177a96552..26030cd7bd9a7 100644 --- a/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc +++ b/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc @@ -17,50 +17,6 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -TEST(BufferWrapperTest, ImplementedFunctions) { - Buffer::OwnedImpl buffer; - addString(buffer, "abcdefghij"); - - BufferWrapper wrapper(buffer); - { - char s[4] = {0}; - wrapper.copyOut(0, 3, s); - EXPECT_EQ("abc", std::string(s)); - EXPECT_EQ(10, wrapper.length()); - EXPECT_EQ(0, wrapper.position()); - } - - { - char s[6] = {0}; - wrapper.copyOut(5, 5, s); - EXPECT_EQ("fghij", std::string(s)); - EXPECT_EQ(10, wrapper.length()); - EXPECT_EQ(0, wrapper.position()); - } - - { - std::string s(static_cast(wrapper.linearize(5)), 5); - EXPECT_EQ("abcde", s); - EXPECT_EQ(0, wrapper.position()); - } - - wrapper.drain(2); - - { - char s[4] = {0}; - wrapper.copyOut(4, 3, s); - EXPECT_EQ("ghi", std::string(s)); - EXPECT_EQ(8, wrapper.length()); - EXPECT_EQ(2, wrapper.position()); - } - - { - std::string s(static_cast(wrapper.linearize(8)), 8); - EXPECT_EQ("cdefghij", s); - EXPECT_EQ(2, wrapper.position()); - } -} - TEST(BufferHelperTest, PeekI8) { { Buffer::OwnedImpl buffer; diff --git a/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc index b784d85e85297..79187821def52 100644 --- a/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc @@ -4,15 +4,12 @@ #include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" -#include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" #include "gtest/gtest.h" -using testing::NiceMock; -using testing::StrictMock; using testing::TestWithParam; using testing::Values; @@ -22,14 +19,12 @@ namespace NetworkFilters { namespace ThriftProxy { TEST(CompactProtocolTest, Name) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_EQ(proto.name(), "compact"); } TEST(CompactProtocolTest, ReadMessageBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -170,7 +165,6 @@ TEST(CompactProtocolTest, ReadMessageBegin) { addInt8(buffer, 32); addInt8(buffer, 0); - EXPECT_CALL(cb, messageStart(absl::string_view(""), MessageType::Call, 32)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, ""); EXPECT_EQ(msg_type, MessageType::Call); @@ -228,7 +222,6 @@ TEST(CompactProtocolTest, ReadMessageBegin) { addInt8(buffer, 8); addString(buffer, "the_name"); - EXPECT_CALL(cb, messageStart(absl::string_view("the_name"), MessageType::Call, 0x0102)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, "the_name"); EXPECT_EQ(msg_type, MessageType::Call); @@ -239,22 +232,19 @@ TEST(CompactProtocolTest, ReadMessageBegin) { TEST(CompactProtocolTest, ReadMessageEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); - EXPECT_CALL(cb, messageComplete()); + CompactProtocolImpl proto; + EXPECT_TRUE(proto.readMessageEnd(buffer)); } TEST(CompactProtocolTest, ReadStruct) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; std::string name = "-"; - EXPECT_CALL(cb, structBegin(absl::string_view(""))); + EXPECT_TRUE(proto.readStructBegin(buffer, name)); EXPECT_EQ(name, ""); - EXPECT_CALL(cb, structEnd()); EXPECT_TRUE(proto.readStructEnd(buffer)); EXPECT_THROW_WITH_MESSAGE(proto.readStructEnd(buffer), EnvoyException, @@ -262,8 +252,7 @@ TEST(CompactProtocolTest, ReadStruct) { } TEST(CompactProtocolTest, ReadFieldBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -287,7 +276,6 @@ TEST(CompactProtocolTest, ReadFieldBegin) { addInt8(buffer, 0xF0); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Stop, 0)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Stop); @@ -400,7 +388,6 @@ TEST(CompactProtocolTest, ReadFieldBegin) { addInt8(buffer, 0x05); addInt8(buffer, 0x04); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::I32, 2)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::I32); @@ -417,7 +404,6 @@ TEST(CompactProtocolTest, ReadFieldBegin) { addInt8(buffer, 0xF5); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::I32, 17)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::I32); @@ -428,14 +414,12 @@ TEST(CompactProtocolTest, ReadFieldBegin) { TEST(CompactProtocolTest, ReadFieldEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readFieldEnd(buffer)); } TEST(CompactProtocolTest, ReadMapBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -575,14 +559,12 @@ TEST(CompactProtocolTest, ReadMapBegin) { TEST(CompactProtocolTest, ReadMapEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readMapEnd(buffer)); } TEST(CompactProtocolTest, ReadListBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -690,14 +672,12 @@ TEST(CompactProtocolTest, ReadListBegin) { TEST(CompactProtocolTest, ReadListEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readListEnd(buffer)); } TEST(CompactProtocolTest, ReadSetBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Test only the happy path, since this method is just delegated to readListBegin() Buffer::OwnedImpl buffer; @@ -714,14 +694,12 @@ TEST(CompactProtocolTest, ReadSetBegin) { TEST(CompactProtocolTest, ReadSetEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readSetEnd(buffer)); } TEST(CompactProtocolTest, ReadBool) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Bool field values are encoded in the field type { @@ -734,7 +712,6 @@ TEST(CompactProtocolTest, ReadBool) { addInt8(buffer, 0x01); addInt8(buffer, 0x04); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Bool, 2)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Bool); @@ -751,7 +728,6 @@ TEST(CompactProtocolTest, ReadBool) { addInt8(buffer, 0x02); addInt8(buffer, 0x06); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Bool, 3)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Bool); @@ -787,8 +763,7 @@ TEST(CompactProtocolTest, ReadBool) { } TEST(CompactProtocolTest, ReadIntegerTypes) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Byte { @@ -919,8 +894,7 @@ TEST(CompactProtocolTest, ReadIntegerTypes) { } TEST(CompactProtocolTest, ReadDouble) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -950,8 +924,7 @@ TEST(CompactProtocolTest, ReadDouble) { } TEST(CompactProtocolTest, ReadString) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -1028,8 +1001,7 @@ TEST(CompactProtocolTest, ReadString) { TEST(CompactProtocolTest, ReadBinary) { // Test only the happy path, since this method is just delegated to readString() - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; std::string value = "-"; @@ -1046,8 +1018,7 @@ class CompactProtocolFieldTypeTest : public TestWithParam {}; TEST_P(CompactProtocolFieldTypeTest, ConvertsToFieldType) { uint8_t compact_field_type = GetParam(); - NiceMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; std::string name = "-"; int8_t invalid_field_type = static_cast(FieldType::LastFieldType) + 1; FieldType field_type = static_cast(invalid_field_type); @@ -1080,8 +1051,7 @@ INSTANTIATE_TEST_CASE_P(CompactFieldTypes, CompactProtocolFieldTypeTest, Values(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)); TEST(CompactProtocolTest, WriteMessageBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Named call { @@ -1099,16 +1069,14 @@ TEST(CompactProtocolTest, WriteMessageBegin) { } TEST(CompactProtocolTest, WriteMessageEnd) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMessageEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(CompactProtocolTest, WriteStruct) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeStructBegin(buffer, "unused"); @@ -1123,16 +1091,14 @@ TEST(CompactProtocolTest, WriteStruct) { TEST(CompactProtocolTest, WriteFieldBegin) { // Stop field { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::Stop, 1); EXPECT_EQ(std::string("\0", 1), buffer.toString()); } { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Short form { @@ -1145,7 +1111,7 @@ TEST(CompactProtocolTest, WriteFieldBegin) { { Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::Struct, 17); - EXPECT_EQ(std::string("\xC\0\x11", 3), buffer.toString()); + EXPECT_EQ(std::string("\xC\x22", 2), buffer.toString()); } // Short form @@ -1164,14 +1130,13 @@ TEST(CompactProtocolTest, WriteFieldBegin) { } { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Long form { Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::I32, 16); - EXPECT_EQ(std::string("\x5\0\x10", 3), buffer.toString()); + EXPECT_EQ(std::string("\x5\x20", 2), buffer.toString()); } // Short form @@ -1185,21 +1150,20 @@ TEST(CompactProtocolTest, WriteFieldBegin) { { Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::Byte, 33); - EXPECT_EQ(std::string("\x3\0\x21", 3), buffer.toString()); + EXPECT_EQ(std::string("\x3\x42", 2), buffer.toString()); } - // Long form + // Long form (3 bytes) { Buffer::OwnedImpl buffer; - proto.writeFieldBegin(buffer, "unused", FieldType::String, 49); - EXPECT_EQ(std::string("\x8\0\x31", 3), buffer.toString()); + proto.writeFieldBegin(buffer, "unused", FieldType::String, 64); + EXPECT_EQ(std::string("\x8\x80\x1", 3), buffer.toString()); } } // Unknown field type { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; int8_t invalid_field_type = static_cast(FieldType::LastFieldType) + 1; @@ -1212,8 +1176,7 @@ TEST(CompactProtocolTest, WriteFieldBegin) { } TEST(CompactProtocolTest, WriteFieldEnd) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeFieldEnd(buffer); EXPECT_EQ(0, buffer.length()); @@ -1224,8 +1187,7 @@ TEST(CompactProtocolTest, WriteBoolField) { // Short form field { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; { Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::Bool, 8); @@ -1245,15 +1207,14 @@ TEST(CompactProtocolTest, WriteBoolField) { // Long form field { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; { Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::Bool, 16); EXPECT_EQ(0, buffer.length()); proto.writeBool(buffer, true); - EXPECT_EQ(std::string("\x1\0\x10", 3), buffer.toString()); + EXPECT_EQ(std::string("\x1\x20", 2), buffer.toString()); } { @@ -1261,14 +1222,13 @@ TEST(CompactProtocolTest, WriteBoolField) { proto.writeFieldBegin(buffer, "unused", FieldType::Bool, 32); EXPECT_EQ(0, buffer.length()); proto.writeBool(buffer, false); - EXPECT_EQ(std::string("\x2\0\x20", 3), buffer.toString()); + EXPECT_EQ(std::string("\x2\x40", 2), buffer.toString()); } } } TEST(CompactProtocolTest, WriteMapBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Empty map { @@ -1294,16 +1254,14 @@ TEST(CompactProtocolTest, WriteMapBegin) { } TEST(CompactProtocolTest, WriteMapEnd) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMapEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(CompactProtocolTest, WriteListBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Empty list { @@ -1335,16 +1293,14 @@ TEST(CompactProtocolTest, WriteListBegin) { } TEST(CompactProtocolTest, WriteListEnd) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeListEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(CompactProtocolTest, WriteSetBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Empty set only, as writeSetBegin delegates to writeListBegin. Buffer::OwnedImpl buffer; @@ -1353,16 +1309,14 @@ TEST(CompactProtocolTest, WriteSetBegin) { } TEST(CompactProtocolTest, WriteSetEnd) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeSetEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(CompactProtocolTest, WriteBool) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Non-field bools (see WriteBoolField test) { @@ -1379,8 +1333,7 @@ TEST(CompactProtocolTest, WriteBool) { } TEST(CompactProtocolTest, WriteByte) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -1396,8 +1349,7 @@ TEST(CompactProtocolTest, WriteByte) { } TEST(CompactProtocolTest, WriteInt16) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // zigzag(1) = 2 { @@ -1436,8 +1388,7 @@ TEST(CompactProtocolTest, WriteInt16) { } TEST(CompactProtocolTest, WriteInt32) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // zigzag(1) = 2 { @@ -1476,8 +1427,7 @@ TEST(CompactProtocolTest, WriteInt32) { } TEST(CompactProtocolTest, WriteInt64) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // zigzag(1) = 2 { @@ -1516,16 +1466,14 @@ TEST(CompactProtocolTest, WriteInt64) { } TEST(CompactProtocolTest, WriteDouble) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeDouble(buffer, 3.0); EXPECT_EQ(std::string("\x40\x8\0\0\0\0\0\0", 8), buffer.toString()); } TEST(CompactProtocolTest, WriteString) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -1551,8 +1499,7 @@ TEST(CompactProtocolTest, WriteString) { } TEST(CompactProtocolTest, WriteBinary) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // writeBinary is an alias for writeString Buffer::OwnedImpl buffer; diff --git a/test/extensions/filters/network/thrift_proxy/config_test.cc b/test/extensions/filters/network/thrift_proxy/config_test.cc index 3047282f9c771..56c481075a01f 100644 --- a/test/extensions/filters/network/thrift_proxy/config_test.cc +++ b/test/extensions/filters/network/thrift_proxy/config_test.cc @@ -31,7 +31,7 @@ TEST(ThriftFilterConfigTest, ValidProtoConfiguration) { ThriftProxyFilterConfigFactory factory; Network::FilterFactoryCb cb = factory.createFilterFactoryFromProto(config, context); Network::MockConnection connection; - EXPECT_CALL(connection, addFilter(_)); + EXPECT_CALL(connection, addReadFilter(_)); cb(connection); } @@ -45,7 +45,7 @@ TEST(ThriftFilterConfigTest, ThriftProxyWithEmptyProto) { Network::FilterFactoryCb cb = factory.createFilterFactoryFromProto(config, context); Network::MockConnection connection; - EXPECT_CALL(connection, addFilter(_)); + EXPECT_CALL(connection, addReadFilter(_)); cb(connection); } diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc new file mode 100644 index 0000000000000..f3cf9ff8444c7 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -0,0 +1,723 @@ +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" + +#include "common/buffer/buffer_impl.h" +#include "common/stats/stats_impl.h" + +#include "extensions/filters/network/thrift_proxy/buffer_helper.h" +#include "extensions/filters/network/thrift_proxy/config.h" +#include "extensions/filters/network/thrift_proxy/conn_manager.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/server/mocks.h" +#include "test/mocks/upstream/mocks.h" +#include "test/test_common/printers.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::Invoke; +using testing::NiceMock; +using testing::Return; +using testing::ReturnRef; +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class TestConfigImpl : public ConfigImpl { +public: + TestConfigImpl( + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy proto_config, + Server::Configuration::MockFactoryContext& context, + ThriftFilters::DecoderFilterSharedPtr decoder_filter, ThriftFilterStats& stats) + : ConfigImpl(proto_config, context), decoder_filter_(decoder_filter), stats_(stats) {} + + // ConfigImpl + ThriftFilterStats& stats() override { return stats_; } + void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override { + callbacks.addDecoderFilter(decoder_filter_); + } + +private: + ThriftFilters::DecoderFilterSharedPtr decoder_filter_; + ThriftFilterStats& stats_; +}; + +class ThriftConnectionManagerTest : public testing::Test { +public: + ThriftConnectionManagerTest() : stats_(ThriftFilterStats::generateStats("test.", store_)) {} + ~ThriftConnectionManagerTest() { + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + void initializeFilter() { initializeFilter(""); } + + void initializeFilter(const std::string& yaml) { + // Destroy any existing filter first. + filter_ = nullptr; + + for (auto counter : store_.counters()) { + counter->reset(); + } + + if (yaml.empty()) { + proto_config_.set_stat_prefix("test"); + } else { + MessageUtil::loadFromYaml(yaml, proto_config_); + MessageUtil::validate(proto_config_); + } + + proto_config_.set_stat_prefix("test"); + + decoder_filter_.reset(new NiceMock()); + config_.reset(new TestConfigImpl(proto_config_, context_, decoder_filter_, stats_)); + + filter_.reset(new ConnectionManager(*config_)); + filter_->initializeReadFilterCallbacks(filter_callbacks_); + filter_->onNewConnection(); + + // NOP currently. + filter_->onAboveWriteBufferHighWatermark(); + filter_->onBelowWriteBufferLowWatermark(); + } + + void writeFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", msg_type, seq_id); + proto->writeStructBegin(msg, "response"); + proto->writeFieldBegin(msg, "success", FieldType::String, 0); + proto->writeString(msg, "field"); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + void writeComplexFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, + int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", msg_type, seq_id); + proto->writeStructBegin(msg, "wrapper"); // call args struct or response struct + proto->writeFieldBegin(msg, "wrapper_field", FieldType::Struct, 0); // call arg/response success + + proto->writeStructBegin(msg, "payload"); + proto->writeFieldBegin(msg, "f1", FieldType::Bool, 1); + proto->writeBool(msg, true); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f2", FieldType::Byte, 2); + proto->writeByte(msg, 2); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f3", FieldType::Double, 3); + proto->writeDouble(msg, 3.0); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f4", FieldType::I16, 4); + proto->writeInt16(msg, 4); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f5", FieldType::I32, 5); + proto->writeInt32(msg, 5); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f6", FieldType::I64, 6); + proto->writeInt64(msg, 6); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f7", FieldType::String, 7); + proto->writeString(msg, "seven"); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f8", FieldType::Map, 8); + proto->writeMapBegin(msg, FieldType::I32, FieldType::I32, 1); + proto->writeInt32(msg, 8); + proto->writeInt32(msg, 8); + proto->writeMapEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f9", FieldType::List, 9); + proto->writeListBegin(msg, FieldType::I32, 1); + proto->writeInt32(msg, 8); + proto->writeListEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f10", FieldType::Set, 10); + proto->writeSetBegin(msg, FieldType::I32, 1); + proto->writeInt32(msg, 8); + proto->writeSetEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); // payload stop field + proto->writeStructEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); // wrapper stop field + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + void writePartialFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, + int32_t seq_id, bool start) { + Buffer::OwnedImpl frame; + writeFramedBinaryMessage(frame, msg_type, seq_id); + + if (start) { + buffer.move(frame, 27); + } else { + frame.drain(27); + buffer.move(frame); + } + } + + void writeFramedBinaryTApplicationException(Buffer::Instance& buffer, int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", MessageType::Exception, seq_id); + proto->writeStructBegin(msg, ""); + proto->writeFieldBegin(msg, "", FieldType::String, 1); + proto->writeString(msg, "error"); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::I32, 2); + proto->writeInt32(msg, 1); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + void writeFramedBinaryIDLException(Buffer::Instance& buffer, int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", MessageType::Reply, seq_id); + proto->writeStructBegin(msg, ""); + proto->writeFieldBegin(msg, "", FieldType::Struct, 2); + + proto->writeStructBegin(msg, ""); + proto->writeFieldBegin(msg, "", FieldType::String, 1); + proto->writeString(msg, "err"); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + NiceMock context_; + std::shared_ptr decoder_filter_; + Stats::IsolatedStoreImpl store_; + ThriftFilterStats stats_; + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy proto_config_; + + std::unique_ptr config_; + + Buffer::OwnedImpl buffer_; + Buffer::OwnedImpl write_buffer_; + std::unique_ptr filter_; + NiceMock filter_callbacks_; +}; + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftCall) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftOneWay) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(0U, store_.counter("test.request_call").value()); + EXPECT_EQ(1U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesStopIterationAndResume) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + EXPECT_CALL(*decoder_filter_, messageBegin(_, _, _)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + // Nothing further happens: we're stopped. + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_EQ(1, callbacks->streamId()); + EXPECT_EQ(TransportType::Framed, callbacks->downstreamTransportType()); + EXPECT_EQ(ProtocolType::Binary, callbacks->downstreamProtocolType()); + EXPECT_EQ(&filter_callbacks_.connection_, callbacks->connection()); + + // Resume processing. + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + callbacks->continueDecoding(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(0U, store_.counter("test.request_call").value()); + EXPECT_EQ(1U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesFrameSplitAcrossBuffers) { + initializeFilter(); + + writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, true); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0, buffer_.length()); + + // Complete the buffer + writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, false); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0, buffer_.length()); + + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesInvalidMsgType) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Reply, 0x0F); // reply is invalid for a request + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(0U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(1U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesProtocolError) { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1f, // framed: 31 bytes + 0x80, 0x01, 0x00, 0x01, // binary, call + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x08, 0xff, 0xff // illegal field id + }); + + std::string err = "invalid binary protocol field id -1"; + addSeq(write_buffer_, { + 0x00, 0x00, 0x00, 0x42, // framed: 66 bytes + 0x80, 0x01, 0x00, 0x03, // binary, exception + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x0b, 0x00, 0x01, // begin string field + }); + addInt32(write_buffer_, err.length()); + addString(write_buffer_, err); + addSeq(write_buffer_, { + 0x08, 0x00, 0x02, // begin i32 field + 0x00, 0x00, 0x00, 0x07, // protocol error + 0x00, // stop field + }); + + EXPECT_CALL(filter_callbacks_.connection_, write(_, false)) + .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { + EXPECT_EQ(bufferToString(write_buffer_), bufferToString(buffer)); + })); + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::FlushWrite)); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesProtocolErrorDuringMessageBegin) { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes + 0x80, 0x01, 0x00, 0xff, // binary, invalid type + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x00, // stop field + }); + + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::FlushWrite)); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnEvent) { + // No active calls + { + initializeFilter(); + filter_->onEvent(Network::ConnectionEvent::RemoteClose); + filter_->onEvent(Network::ConnectionEvent::LocalClose); + EXPECT_EQ(0U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + EXPECT_EQ(0U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); + } + + // Close mid-request + { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes + 0x80, 0x01, 0x00, 0x01, // binary proto, call type + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x0F, // seq id + }); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + filter_->onEvent(Network::ConnectionEvent::RemoteClose); + EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + + filter_->onEvent(Network::ConnectionEvent::LocalClose); + EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); + + buffer_.drain(buffer_.length()); + } + + // Close before response + { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + filter_->onEvent(Network::ConnectionEvent::RemoteClose); + EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + + filter_->onEvent(Network::ConnectionEvent::LocalClose); + EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); + + buffer_.drain(buffer_.length()); + } +} + +TEST_F(ThriftConnectionManagerTest, Routing) { + const std::string yaml = R"EOF( +transport: FRAMED +protocol: BINARY +stat_prefix: test +route_config: + name: "routes" + routes: + - match: + method: name + route: + cluster: cluster +)EOF"; + + initializeFilter(yaml); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + EXPECT_CALL(*decoder_filter_, messageBegin(_, _, _)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + Router::RouteConstSharedPtr route = callbacks->route(); + EXPECT_NE(nullptr, route); + EXPECT_NE(nullptr, route->routeEntry()); + EXPECT_EQ("cluster", route->routeEntry()->clusterName()); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + callbacks->continueDecoding(); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndResponse) { + initializeFilter(); + writeComplexFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeComplexFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(1U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndExceptionResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeFramedBinaryTApplicationException(write_buffer_, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); + EXPECT_EQ(1U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndErrorResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeFramedBinaryIDLException(write_buffer_, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(1U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndInvalidResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + // Call is not valid in a response + writeFramedBinaryMessage(write_buffer_, MessageType::Call, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(1U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndResponseProtocolError) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + // illegal field id + addSeq(write_buffer_, { + 0x00, 0x00, 0x00, 0x1f, // framed: 31 bytes + 0x80, 0x01, 0x00, 0x02, // binary, reply + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x08, 0xff, 0xff // illegal field id + }); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_, write(_, false)); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(*decoder_filter_, resetUpstreamConnection()); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); + EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, PipelinedRequestAndResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x01); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x02); + + std::list callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillRepeatedly(Invoke( + [&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks.push_back(&cb); })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(2U, store_.gauge("test.request_active").value()); + EXPECT_EQ(2U, store_.counter("test.request").value()); + EXPECT_EQ(2U, store_.counter("test.request_call").value()); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(2); + + writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x01); + callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); + callbacks.pop_front(); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + + writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x02); + callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); + callbacks.pop_front(); + EXPECT_EQ(2U, store_.counter("test.response").value()); + EXPECT_EQ(2U, store_.counter("test.response_reply").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +TEST_F(ThriftConnectionManagerTest, ResetDownstreamConnection) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + callbacks->resetDownstreamConnection(); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index 8d19f4e2b45ba..3417c4ca0219b 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -7,6 +7,7 @@ #include "test/test_common/printers.h" #include "test/test_common/utility.h" +#include "absl/strings/string_view.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -33,48 +34,96 @@ namespace NetworkFilters { namespace ThriftProxy { namespace { -Expectation expectValue(NiceMock& proto, FieldType field_type, bool result = true) { +ExpectationSet expectValue(MockProtocol& proto, ThriftFilters::MockDecoderFilter& filter, + FieldType field_type, bool result = true) { + ExpectationSet s; switch (field_type) { case FieldType::Bool: - return EXPECT_CALL(proto, readBool(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readBool(_, _)).WillOnce(Return(result)); + if (result) { + s += + EXPECT_CALL(filter, boolValue(_)).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::Byte: - return EXPECT_CALL(proto, readByte(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readByte(_, _)).WillOnce(Return(result)); + if (result) { + s += + EXPECT_CALL(filter, byteValue(_)).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::Double: - return EXPECT_CALL(proto, readDouble(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readDouble(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, doubleValue(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::I16: - return EXPECT_CALL(proto, readInt16(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readInt16(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, int16Value(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::I32: - return EXPECT_CALL(proto, readInt32(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readInt32(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, int32Value(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::I64: - return EXPECT_CALL(proto, readInt64(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readInt64(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, int64Value(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::String: - return EXPECT_CALL(proto, readString(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readString(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, stringValue(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; default: NOT_REACHED; } + return s; } -ExpectationSet expectContainerStart(NiceMock& proto, FieldType field_type, - FieldType inner_type) { +ExpectationSet expectContainerStart(MockProtocol& proto, ThriftFilters::MockDecoderFilter& filter, + FieldType field_type, FieldType inner_type) { ExpectationSet s; switch (field_type) { case FieldType::Struct: s += EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); s += EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(inner_type), SetArgReferee<3>(1), Return(true))); + s += EXPECT_CALL(filter, fieldBegin(absl::string_view(), inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::List: s += EXPECT_CALL(proto, readListBegin(_, _, _)) .WillOnce(DoAll(SetArgReferee<1>(inner_type), SetArgReferee<2>(1), Return(true))); + s += EXPECT_CALL(filter, listBegin(inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Map: s += EXPECT_CALL(proto, readMapBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(inner_type), SetArgReferee<2>(inner_type), SetArgReferee<3>(1), Return(true))); + s += EXPECT_CALL(filter, mapBegin(inner_type, inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Set: s += EXPECT_CALL(proto, readSetBegin(_, _, _)) .WillOnce(DoAll(SetArgReferee<1>(inner_type), SetArgReferee<2>(1), Return(true))); + s += EXPECT_CALL(filter, setBegin(inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; default: NOT_REACHED; @@ -82,23 +131,29 @@ ExpectationSet expectContainerStart(NiceMock& proto, FieldType fie return s; } -ExpectationSet expectContainerEnd(NiceMock& proto, FieldType field_type) { +ExpectationSet expectContainerEnd(MockProtocol& proto, ThriftFilters::MockDecoderFilter& filter, + FieldType field_type) { ExpectationSet s; switch (field_type) { case FieldType::Struct: s += EXPECT_CALL(proto, readFieldEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); s += EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); s += EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::List: s += EXPECT_CALL(proto, readListEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, listEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Map: s += EXPECT_CALL(proto, readMapEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, mapEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Set: s += EXPECT_CALL(proto, readSetEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, setEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; default: NOT_REACHED; @@ -153,7 +208,8 @@ TEST_P(DecoderStateMachineNonValueTest, NoData) { ProtocolState state = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; - DecoderStateMachine dsm(proto); + StrictMock filter; + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(state); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), state); @@ -164,17 +220,18 @@ TEST_P(DecoderStateMachineValueTest, NoFieldValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(std::string("")), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, field_type, false); - expectValue(proto, field_type, true); + expectValue(proto, filter, field_type, false); + expectValue(proto, filter, field_type, true); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::FieldBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -188,18 +245,19 @@ TEST_P(DecoderStateMachineValueTest, FieldValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(std::string("")), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::FieldBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -209,13 +267,14 @@ TEST_P(DecoderStateMachineValueTest, FieldValue) { TEST(DecoderStateMachineTest, NoListValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(1), Return(true))); EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -225,13 +284,14 @@ TEST(DecoderStateMachineTest, NoListValueData) { TEST(DecoderStateMachineTest, EmptyList) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(0), Return(true))); EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -242,16 +302,17 @@ TEST_P(DecoderStateMachineValueTest, ListValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(1), Return(true))); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -262,18 +323,19 @@ TEST_P(DecoderStateMachineValueTest, MultipleListValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, field_type); + expectValue(proto, filter, field_type); } EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -283,6 +345,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleListValues) { TEST(DecoderStateMachineTest, NoMapKeyData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -290,7 +353,7 @@ TEST(DecoderStateMachineTest, NoMapKeyData) { SetArgReferee<3>(1), Return(true))); EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -300,6 +363,7 @@ TEST(DecoderStateMachineTest, NoMapKeyData) { TEST(DecoderStateMachineTest, NoMapValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -308,7 +372,7 @@ TEST(DecoderStateMachineTest, NoMapValueData) { EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(proto, readString(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -318,6 +382,7 @@ TEST(DecoderStateMachineTest, NoMapValueData) { TEST(DecoderStateMachineTest, EmptyMap) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -325,7 +390,7 @@ TEST(DecoderStateMachineTest, EmptyMap) { SetArgReferee<3>(0), Return(true))); EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -336,18 +401,19 @@ TEST_P(DecoderStateMachineValueTest, MapKeyValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(FieldType::String), SetArgReferee<3>(1), Return(true))); - expectValue(proto, field_type); // key - expectValue(proto, FieldType::String); // value + expectValue(proto, filter, field_type); // key + expectValue(proto, filter, FieldType::String); // value EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -358,18 +424,19 @@ TEST_P(DecoderStateMachineValueTest, MapValueValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, FieldType::I32); // key - expectValue(proto, field_type); // value + expectValue(proto, filter, FieldType::I32); // key + expectValue(proto, filter, field_type); // value EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -380,6 +447,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -387,13 +455,13 @@ TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { SetArgReferee<3>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, FieldType::I32); // key - expectValue(proto, field_type); // value + expectValue(proto, filter, FieldType::I32); // key + expectValue(proto, filter, field_type); // value } EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -403,13 +471,14 @@ TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { TEST(DecoderStateMachineTest, NoSetValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(1), Return(true))); EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -419,13 +488,14 @@ TEST(DecoderStateMachineTest, NoSetValueData) { TEST(DecoderStateMachineTest, EmptySet) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(0), Return(true))); EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -436,16 +506,17 @@ TEST_P(DecoderStateMachineValueTest, SetValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(1), Return(true))); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -456,18 +527,19 @@ TEST_P(DecoderStateMachineValueTest, MultipleSetValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, field_type); + expectValue(proto, filter, field_type); } EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -477,6 +549,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleSetValues) { TEST(DecoderStateMachineTest, EmptyStruct) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) @@ -488,7 +561,7 @@ TEST(DecoderStateMachineTest, EmptyStruct) { EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -498,24 +571,39 @@ TEST_P(DecoderStateMachineValueTest, SingleFieldStruct) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + StrictMock filter; InSequence dummy; EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); + EXPECT_CALL(filter, fieldBegin(absl::string_view(), field_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -524,6 +612,7 @@ TEST_P(DecoderStateMachineValueTest, SingleFieldStruct) { TEST(DecoderStateMachineTest, MultiFieldStruct) { Buffer::OwnedImpl buffer; NiceMock proto; + StrictMock filter; InSequence dummy; std::vector field_types = {FieldType::Bool, FieldType::Byte, FieldType::Double, @@ -533,24 +622,36 @@ TEST(DecoderStateMachineTest, MultiFieldStruct) { EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); int16_t field_id = 1; for (FieldType field_type : field_types) { EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) - .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(field_id++), Return(true))); + .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(field_id), Return(true))); + EXPECT_CALL(filter, fieldBegin(absl::string_view(), field_type, field_id)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + field_id++; - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); } EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -562,35 +663,41 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { Buffer::OwnedImpl buffer; NiceMock proto; + StrictMock filter; InSequence dummy; // start of message and outermost struct EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); - expectContainerStart(proto, FieldType::Struct, outer_field_type); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + + expectContainerStart(proto, filter, FieldType::Struct, outer_field_type); - expectContainerStart(proto, outer_field_type, inner_type); + expectContainerStart(proto, filter, outer_field_type, inner_type); int outer_reps = outer_field_type == FieldType::Map ? 2 : 1; for (int i = 0; i < outer_reps; i++) { - expectContainerStart(proto, inner_type, value_type); + expectContainerStart(proto, filter, inner_type, value_type); int inner_reps = inner_type == FieldType::Map ? 2 : 1; for (int j = 0; j < inner_reps; j++) { - expectValue(proto, value_type); + expectValue(proto, filter, value_type); } - expectContainerEnd(proto, inner_type); + expectContainerEnd(proto, filter, inner_type); } - expectContainerEnd(proto, outer_field_type); + expectContainerEnd(proto, filter, outer_field_type); // end of message and outermost struct - expectContainerEnd(proto, FieldType::Struct); + expectContainerEnd(proto, filter, FieldType::Struct); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -599,63 +706,135 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { TEST(DecoderTest, OnData) { NiceMock* transport = new NiceMock(); NiceMock* proto = new NiceMock(); + NiceMock callbacks; + StrictMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}); + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); + EXPECT_CALL(filter, transportBegin(absl::optional(100))) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(filter, transportEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - decoder.onData(buffer); + bool underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); } TEST(DecoderTest, OnDataResumes) { NiceMock* transport = new NiceMock(); NiceMock* proto = new NiceMock(); + NiceMock callbacks; + NiceMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}); + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); Buffer::OwnedImpl buffer; + buffer.add("x"); - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(false)); - decoder.onData(buffer); + bool underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); + + EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); + EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); // buffer.length() == 1 +} + +TEST(DecoderTest, OnDataResumesTransportFrameStart) { + StrictMock* transport = new StrictMock(); + StrictMock* proto = new StrictMock(); + NiceMock callbacks; + NiceMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + + EXPECT_CALL(*transport, name()).Times(AnyNumber()); + EXPECT_CALL(*proto, name()).Times(AnyNumber()); + + InSequence dummy; + + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Buffer::OwnedImpl buffer; + bool underflow = false; + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(false))); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); + EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) + .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), + SetArgReferee<3>(100), Return(true))); EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(false)); - decoder.onData(buffer); + + underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); // buffer.length() == 0 } TEST(DecoderTest, OnDataResumesTransportFrameEnd) { StrictMock* transport = new StrictMock(); StrictMock* proto = new StrictMock(); + NiceMock callbacks; + NiceMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); EXPECT_CALL(*transport, name()).Times(AnyNumber()); EXPECT_CALL(*proto, name()).Times(AnyNumber()); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}); + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); @@ -665,11 +844,90 @@ TEST(DecoderTest, OnDataResumesTransportFrameEnd) { EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(false)); - decoder.onData(buffer); + + bool underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(false)); - decoder.onData(buffer); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); // buffer.length() == 0 +} + +TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { + + StrictMock* transport = new StrictMock(); + EXPECT_CALL(*transport, name()).WillRepeatedly(ReturnRef(transport->name_)); + + StrictMock* proto = new StrictMock(); + EXPECT_CALL(*proto, name()).WillRepeatedly(ReturnRef(proto->name_)); + + NiceMock callbacks; + StrictMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + + InSequence dummy; + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Buffer::OwnedImpl buffer; + bool underflow = true; + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); + EXPECT_CALL(filter, transportBegin(absl::optional(100))) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _, _, _)) + .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), + SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + .WillOnce(DoAll(SetArgReferee<2>(FieldType::I32), SetArgReferee<3>(1), Return(true))); + EXPECT_CALL(filter, fieldBegin(absl::string_view(), FieldType::I32, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readInt32(_, _)).WillOnce(Return(true)); + EXPECT_CALL(filter, int32Value(_)).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); + EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, transportEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); } #define TEST_NAME(X) EXPECT_EQ(ProtocolStateNameValues::name(ProtocolState::X), #X); diff --git a/test/extensions/filters/network/thrift_proxy/filter_test.cc b/test/extensions/filters/network/thrift_proxy/filter_test.cc deleted file mode 100644 index d192e47372098..0000000000000 --- a/test/extensions/filters/network/thrift_proxy/filter_test.cc +++ /dev/null @@ -1,559 +0,0 @@ -#include "common/buffer/buffer_impl.h" -#include "common/stats/stats_impl.h" - -#include "extensions/filters/network/thrift_proxy/buffer_helper.h" -#include "extensions/filters/network/thrift_proxy/filter.h" - -#include "test/extensions/filters/network/thrift_proxy/utility.h" -#include "test/mocks/network/mocks.h" -#include "test/test_common/printers.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using testing::NiceMock; - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -class ThriftFilterTest : public testing::Test { -public: - ThriftFilterTest() {} - - void initializeFilter() { - for (auto counter : store_.counters()) { - counter->reset(); - } - - filter_.reset(new Filter("test.", store_)); - filter_->initializeReadFilterCallbacks(read_filter_callbacks_); - filter_->onNewConnection(); - - // NOP currently. - filter_->onAboveWriteBufferHighWatermark(); - filter_->onBelowWriteBufferLowWatermark(); - } - - void writeFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, int32_t seq_id) { - uint8_t mt = static_cast(msg_type); - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, mt, // binary proto, type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0b, 0x00, 0x00, // begin string field - 0x00, 0x00, 0x00, 0x05, 'f', 'i', 'e', 'l', 'd', // string - 0x00, // stop field - }); - } - - void writePartialFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, - int32_t seq_id, bool start) { - if (start) { - uint8_t mt = static_cast(msg_type); - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x2d, // framed: 45 bytes - 0x80, 0x01, 0x00, mt, // binary proto, type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0c, 0x00, 0x00, // begin struct field - 0x0b, 0x00, 0x01, // begin string field - 0x00, 0x00, 0x00, 0x05 // string length only - }); - } else { - addSeq(buffer, { - 'f', 'i', 'e', 'l', 'd', // string data - 0x0b, 0x00, 0x02, // begin string field - 0x00, 0x00, 0x00, 0x05, 'x', 'x', 'x', 'x', 'x', // string - 0x00, // stop field - 0x00, // stop field - }); - } - } - - void writeFramedBinaryTApplicationException(Buffer::Instance& buffer, int32_t seq_id) { - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x24, // framed: 36 bytes - 0x80, 0x01, 0x00, 0x03, // binary, exception - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0B, 0x00, 0x01, // begin string field - 0x00, 0x00, 0x00, 0x05, 'e', 'r', 'r', 'o', 'r', // string - 0x08, 0x00, 0x02, // begin i32 field - 0x00, 0x00, 0x00, 0x01, // exception type 1 - 0x00, // stop field - }); - } - - void writeFramedBinaryIDLException(Buffer::Instance& buffer, int32_t seq_id) { - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x23, // framed: 35 bytes - 0x80, 0x01, 0x00, 0x02, // binary proto, reply - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0C, 0x00, 0x02, // begin exception struct - 0x0B, 0x00, 0x01, // begin string field - 0x00, 0x00, 0x00, 0x03, 'e', 'r', 'r', // string - 0x00, // exception struct stop - 0x00, // reply struct stop field - }); - } - - Buffer::OwnedImpl buffer_; - Buffer::OwnedImpl write_buffer_; - Stats::IsolatedStoreImpl store_; - std::unique_ptr filter_; - NiceMock read_filter_callbacks_; -}; - -TEST_F(ThriftFilterTest, OnDataHandlesThriftCall) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.counter("test.request_call").value()); - EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); - EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - EXPECT_EQ(0U, store_.counter("test.response").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesThriftOneWay) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(0U, store_.counter("test.request_call").value()); - EXPECT_EQ(1U, store_.counter("test.request_oneway").value()); - EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); - EXPECT_EQ(0U, store_.counter("test.response").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesFrameSplitAcrossBuffers) { - initializeFilter(); - - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, true); - std::string expected_contents = bufferToString(buffer_); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - // Filter passes on the partial buffer, up to the last 4 bytes which it needs to resume the - // decoder on the next call. - std::string contents = bufferToString(buffer_); - EXPECT_EQ(len - 4, buffer_.length()); - EXPECT_EQ(expected_contents.substr(0, len - 4), contents); - - buffer_.drain(buffer_.length()); - - // Complete the buffer - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, false); - expected_contents = expected_contents.substr(len - 4) + bufferToString(buffer_); - len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - // Filter buffered bytes from end of first buffer and passes them on now. - contents = bufferToString(buffer_); - EXPECT_EQ(len + 4, buffer_.length()); - EXPECT_EQ(expected_contents, contents); - - EXPECT_EQ(1U, store_.counter("test.request_call").value()); - EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesInvalidMsgType) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Reply, 0x0F); // reply is invalid for a request - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(0U, store_.counter("test.request_call").value()); - EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); - EXPECT_EQ(1U, store_.counter("test.request_invalid_type").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - EXPECT_EQ(0U, store_.counter("test.response").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesProtocolError) { - initializeFilter(); - addSeq(buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x01, // sequence id - 0x00, // struct stop field - }); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); - EXPECT_EQ(len, buffer_.length()); - - // Sniffing is now disabled. - buffer_.drain(buffer_.length()); - writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(0U, store_.counter("test.request").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesProtocolErrorOnWrite) { - initializeFilter(); - - // Start the read buffer - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, true); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - len -= buffer_.length(); - - // Disable sniffing - addSeq(write_buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x01, // sequence id - 0x00, // struct stop field - }); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); - - // Complete the read buffer - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, false); - len += buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - len -= buffer_.length(); - EXPECT_EQ(0, len); -} - -TEST_F(ThriftFilterTest, OnDataStopsSniffingWithTooManyPendingCalls) { - initializeFilter(); - for (int i = 0; i < 64; i++) { - writeFramedBinaryMessage(buffer_, MessageType::Call, i); - } - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(64U, store_.gauge("test.request_active").value()); - buffer_.drain(buffer_.length()); - - // Sniffing is now disabled. - writeFramedBinaryMessage(buffer_, MessageType::Oneway, 100); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(64U, store_.gauge("test.request_active").value()); - EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesThriftReply) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(1U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.counter("test.response_exception").value()); - EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesOutOrOrderThriftReply) { - initializeFilter(); - - // set up two requests - writeFramedBinaryMessage(buffer_, MessageType::Call, 1); - writeFramedBinaryMessage(buffer_, MessageType::Call, 2); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(2U, store_.counter("test.request").value()); - EXPECT_EQ(2U, store_.gauge("test.request_active").value()); - - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 2); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(1U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - write_buffer_.drain(write_buffer_.length()); - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 1); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(2U, store_.counter("test.response").value()); - EXPECT_EQ(2U, store_.counter("test.response_reply").value()); - EXPECT_EQ(2U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesFrameSplitAcrossBuffers) { - initializeFilter(); - - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F, true); - std::string expected_contents = bufferToString(write_buffer_); - uint64_t len = write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - // Filter passes on the partial buffer, up to the last 4 bytes which it needs to resume the - // decoder on the next call. - std::string contents = bufferToString(write_buffer_); - EXPECT_EQ(len - 4, write_buffer_.length()); - EXPECT_EQ(expected_contents.substr(0, len - 4), contents); - - write_buffer_.drain(write_buffer_.length()); - - // Complete the buffer - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F, false); - expected_contents = expected_contents.substr(len - 4) + bufferToString(write_buffer_); - len = write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - // Filter buffered bytes from end of first buffer and passes them on now. - contents = bufferToString(write_buffer_); - EXPECT_EQ(len + 4, write_buffer_.length()); - EXPECT_EQ(expected_contents, contents); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(1U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesTApplicationException) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryTApplicationException(write_buffer_, 0x0F); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(0U, store_.counter("test.response_reply").value()); - EXPECT_EQ(0U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(1U, store_.counter("test.response_exception").value()); - EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesIDLException) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryIDLException(write_buffer_, 0x0F); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(0U, store_.counter("test.response_success").value()); - EXPECT_EQ(1U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.counter("test.response_exception").value()); - EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesInvalidMsgType) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryMessage(write_buffer_, MessageType::Call, 0x0F); // call is invalid for response - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(0U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.counter("test.response_exception").value()); - EXPECT_EQ(1U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesProtocolError) { - initializeFilter(); - addSeq(write_buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x01, // sequence id - 0x00, // struct stop field - }); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(len, buffer_.length()); - - // Sniffing is now disabled. - write_buffer_.drain(write_buffer_.length()); - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 1); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesProtocolErrorOnData) { - initializeFilter(); - - // Set up a request for the partial write - writeFramedBinaryMessage(buffer_, MessageType::Call, 1); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - buffer_.drain(buffer_.length()); - - // Start the write buffer - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 1, true); - uint64_t len = write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - len -= write_buffer_.length(); - - // Force an error on the next request. - addSeq(buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x02, // sequence id - 0x00, // struct stop field - }); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); - - // Complete the read buffer - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 1, false); - len += write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - len -= write_buffer_.length(); - EXPECT_EQ(0, len); -} - -TEST_F(ThriftFilterTest, OnEvent) { - // No active calls - { - initializeFilter(); - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(0U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - EXPECT_EQ(0U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - } - - // Close mid-request - { - initializeFilter(); - addSeq(buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0x01, // binary proto, call type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x0F, // seq id - }); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - - buffer_.drain(buffer_.length()); - } - - // Close before response - { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - - buffer_.drain(buffer_.length()); - } - - // Close mid-response - { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - addSeq(write_buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0x02, // binary proto, reply type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x0F, // seq id - }); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - - buffer_.drain(buffer_.length()); - write_buffer_.drain(write_buffer_.length()); - } -} - -TEST_F(ThriftFilterTest, ResponseWithUnknownSequenceID) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x10); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); -} - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc index 76ed15709a2f3..2999de7edbf53 100644 --- a/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc @@ -4,80 +4,80 @@ #include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" -#include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" -#include "test/mocks/buffer/mocks.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" -#include "gmock/gmock.h" #include "gtest/gtest.h" -using testing::StrictMock; - namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { TEST(FramedTransportTest, Name) { - StrictMock cb; - FramedTransportImpl transport(cb); + FramedTransportImpl transport; EXPECT_EQ(transport.name(), "framed"); } +TEST(FramedTransportTest, Type) { + FramedTransportImpl transport; + EXPECT_EQ(transport.type(), TransportType::Framed); +} + TEST(FramedTransportTest, NotEnoughData) { Buffer::OwnedImpl buffer; - StrictMock cb; - FramedTransportImpl transport(cb); + FramedTransportImpl transport; + absl::optional size = 1; - EXPECT_FALSE(transport.decodeFrameStart(buffer)); + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(1), size); addRepeated(buffer, 3, 0); - EXPECT_FALSE(transport.decodeFrameStart(buffer)); + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(1), size); } TEST(FramedTransportTest, InvalidFrameSize) { - StrictMock cb; - FramedTransportImpl transport(cb); + FramedTransportImpl transport; { Buffer::OwnedImpl buffer; addInt32(buffer, -1); - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, + absl::optional size = 1; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, "invalid thrift framed transport frame size -1"); + EXPECT_EQ(absl::optional(1), size); } { Buffer::OwnedImpl buffer; addInt32(buffer, 0x7fffffff); - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, + absl::optional size = 1; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, "invalid thrift framed transport frame size 2147483647"); + EXPECT_EQ(absl::optional(1), size); } } TEST(FramedTransportTest, DecodeFrameStart) { - StrictMock cb; - EXPECT_CALL(cb, transportFrameStart(absl::optional(100U))); - - FramedTransportImpl transport(cb); + FramedTransportImpl transport; Buffer::OwnedImpl buffer; addInt32(buffer, 100); - EXPECT_EQ(buffer.length(), 4); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(100U), size); EXPECT_EQ(buffer.length(), 0); } TEST(FramedTransportTest, DecodeFrameEnd) { - StrictMock cb; - EXPECT_CALL(cb, transportFrameComplete()); - - FramedTransportImpl transport(cb); + FramedTransportImpl transport; Buffer::OwnedImpl buffer; @@ -85,9 +85,7 @@ TEST(FramedTransportTest, DecodeFrameEnd) { } TEST(FramedTransportTest, EncodeFrame) { - StrictMock cb; - - FramedTransportImpl transport(cb); + FramedTransportImpl transport; { Buffer::OwnedImpl message; diff --git a/test/extensions/filters/network/thrift_proxy/filter_integration_test.cc b/test/extensions/filters/network/thrift_proxy/integration_test.cc similarity index 88% rename from test/extensions/filters/network/thrift_proxy/filter_integration_test.cc rename to test/extensions/filters/network/thrift_proxy/integration_test.cc index 8c76137a508c3..1e17ef45b92be 100644 --- a/test/extensions/filters/network/thrift_proxy/filter_integration_test.cc +++ b/test/extensions/filters/network/thrift_proxy/integration_test.cc @@ -29,11 +29,11 @@ enum class CallResult { Exception, }; -class ThriftFilterIntegrationTest +class ThriftConnManagerIntegrationTest : public BaseIntegrationTest, public TestWithParam> { public: - ThriftFilterIntegrationTest() + ThriftConnManagerIntegrationTest() : BaseIntegrationTest(Network::Address::IpVersion::v4, thrift_config) {} static void SetUpTestCase() { @@ -43,10 +43,17 @@ class ThriftFilterIntegrationTest - name: envoy.filters.network.thrift_proxy config: stat_prefix: thrift_stats - - name: envoy.tcp_proxy - config: - stat_prefix: tcp_stats - cluster: cluster_0 + route_config: + name: "routes" + routes: + - match: + method_name: "execute" + route: + cluster: "cluster_0" + - match: + method_name: "poke" + route: + cluster: "cluster_0" )EOF"; } @@ -162,12 +169,12 @@ paramToString(const TestParamInfo>& p } INSTANTIATE_TEST_CASE_P( - TransportAndProtocol, ThriftFilterIntegrationTest, + TransportAndProtocol, ThriftConnManagerIntegrationTest, Combine(Values(TransportNames::get().FRAMED, TransportNames::get().UNFRAMED), Values(ProtocolNames::get().BINARY, ProtocolNames::get().COMPACT), Values(false, true)), paramToString); -TEST_P(ThriftFilterIntegrationTest, Success) { +TEST_P(ThriftConnManagerIntegrationTest, Success) { initializeCall(CallResult::Success); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); @@ -176,13 +183,12 @@ TEST_P(ThriftFilterIntegrationTest, Success) { FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); Buffer::OwnedImpl upstream_request( fake_upstream_connection->waitForData(request_bytes_.length())); - EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); fake_upstream_connection->write(response_bytes_.toString()); tcp_client->waitForData(response_bytes_.toString()); tcp_client->close(); - fake_upstream_connection->waitForDisconnect(); EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); @@ -192,7 +198,7 @@ TEST_P(ThriftFilterIntegrationTest, Success) { EXPECT_EQ(1U, counter->value()); } -TEST_P(ThriftFilterIntegrationTest, IDLException) { +TEST_P(ThriftConnManagerIntegrationTest, IDLException) { initializeCall(CallResult::IDLException); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); @@ -201,13 +207,12 @@ TEST_P(ThriftFilterIntegrationTest, IDLException) { FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); Buffer::OwnedImpl upstream_request( fake_upstream_connection->waitForData(request_bytes_.length())); - EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); fake_upstream_connection->write(response_bytes_.toString()); tcp_client->waitForData(response_bytes_.toString()); tcp_client->close(); - fake_upstream_connection->waitForDisconnect(); EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); @@ -217,7 +222,7 @@ TEST_P(ThriftFilterIntegrationTest, IDLException) { EXPECT_EQ(1U, counter->value()); } -TEST_P(ThriftFilterIntegrationTest, Exception) { +TEST_P(ThriftConnManagerIntegrationTest, Exception) { initializeCall(CallResult::Exception); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); @@ -226,13 +231,12 @@ TEST_P(ThriftFilterIntegrationTest, Exception) { FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); Buffer::OwnedImpl upstream_request( fake_upstream_connection->waitForData(request_bytes_.length())); - EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); fake_upstream_connection->write(response_bytes_.toString()); tcp_client->waitForData(response_bytes_.toString()); tcp_client->close(); - fake_upstream_connection->waitForDisconnect(); EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); @@ -242,7 +246,7 @@ TEST_P(ThriftFilterIntegrationTest, Exception) { EXPECT_EQ(1U, counter->value()); } -TEST_P(ThriftFilterIntegrationTest, Oneway) { +TEST_P(ThriftConnManagerIntegrationTest, Oneway) { initializeOneway(); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); @@ -251,10 +255,9 @@ TEST_P(ThriftFilterIntegrationTest, Oneway) { FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); Buffer::OwnedImpl upstream_request( fake_upstream_connection->waitForData(request_bytes_.length())); - EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); tcp_client->close(); - fake_upstream_connection->waitForDisconnect(); Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_oneway"); EXPECT_EQ(1U, counter->value()); diff --git a/test/extensions/filters/network/thrift_proxy/mocks.cc b/test/extensions/filters/network/thrift_proxy/mocks.cc index b44d9b95dab57..caa93654233e8 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.cc +++ b/test/extensions/filters/network/thrift_proxy/mocks.cc @@ -2,25 +2,77 @@ #include "gtest/gtest.h" +using testing::Return; using testing::ReturnRef; +using testing::_; namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -MockTransportCallbacks::MockTransportCallbacks() {} -MockTransportCallbacks::~MockTransportCallbacks() {} +MockConfig::MockConfig() {} +MockConfig::~MockConfig() {} -MockTransport::MockTransport() { ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); } +MockTransport::MockTransport() { + ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); + ON_CALL(*this, type()).WillByDefault(Return(type_)); +} MockTransport::~MockTransport() {} -MockProtocolCallbacks::MockProtocolCallbacks() {} -MockProtocolCallbacks::~MockProtocolCallbacks() {} - -MockProtocol::MockProtocol() { ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); } +MockProtocol::MockProtocol() { + ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); + ON_CALL(*this, type()).WillByDefault(Return(type_)); +} MockProtocol::~MockProtocol() {} +MockDecoderCallbacks::MockDecoderCallbacks() {} +MockDecoderCallbacks::~MockDecoderCallbacks() {} + +namespace ThriftFilters { + +MockDecoderFilter::MockDecoderFilter() { + ON_CALL(*this, transportBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, transportEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, messageBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, messageEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, structBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, structEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, fieldBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, fieldEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, boolValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, byteValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int16Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int32Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int64Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, doubleValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, stringValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, mapBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, mapEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, listBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, listEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, setBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, setEnd()).WillByDefault(Return(FilterStatus::Continue)); +} +MockDecoderFilter::~MockDecoderFilter() {} + +MockDecoderFilterCallbacks::MockDecoderFilterCallbacks() { + ON_CALL(*this, streamId()).WillByDefault(Return(stream_id_)); + ON_CALL(*this, connection()).WillByDefault(Return(&connection_)); +} +MockDecoderFilterCallbacks::~MockDecoderFilterCallbacks() {} + +} // namespace ThriftFilters + +namespace Router { + +MockRouteEntry::MockRouteEntry() {} +MockRouteEntry::~MockRouteEntry() {} + +MockRoute::MockRoute() {} +MockRoute::~MockRoute() {} + +} // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index 1668c42593ef9..6fae679009316 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -1,25 +1,33 @@ #pragma once +#include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/router/router.h" #include "extensions/filters/network/thrift_proxy/transport.h" +#include "test/mocks/network/mocks.h" #include "test/test_common/printers.h" #include "gmock/gmock.h" +using testing::NiceMock; + namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -class MockTransportCallbacks : public TransportCallbacks { +class MockConfig : public Config { public: - MockTransportCallbacks(); - ~MockTransportCallbacks(); - - // ThriftProxy::TransportCallbacks - MOCK_METHOD1(transportFrameStart, void(absl::optional size)); - MOCK_METHOD0(transportFrameComplete, void()); + MockConfig(); + ~MockConfig(); + + // ThriftProxy::Config + MOCK_METHOD0(filterFactory, ThriftFilters::FilterChainFactory&()); + MOCK_METHOD0(stats, ThriftFilterStats&()); + MOCK_METHOD1(createDecoder, DecoderPtr(DecoderCallbacks&)); + MOCK_METHOD0(routerConfig, Router::Config&()); }; class MockTransport : public Transport { @@ -29,24 +37,13 @@ class MockTransport : public Transport { // ThriftProxy::Transport MOCK_CONST_METHOD0(name, const std::string&()); - MOCK_METHOD1(decodeFrameStart, bool(Buffer::Instance&)); + MOCK_CONST_METHOD0(type, TransportType()); + MOCK_METHOD2(decodeFrameStart, bool(Buffer::Instance&, absl::optional&)); MOCK_METHOD1(decodeFrameEnd, bool(Buffer::Instance&)); MOCK_METHOD2(encodeFrame, void(Buffer::Instance&, Buffer::Instance&)); std::string name_{"mock"}; -}; - -class MockProtocolCallbacks : public ProtocolCallbacks { -public: - MockProtocolCallbacks(); - ~MockProtocolCallbacks(); - - // ThriftProxy::ProtocolCallbacks - MOCK_METHOD3(messageStart, void(const absl::string_view, MessageType, int32_t)); - MOCK_METHOD1(structBegin, void(const absl::string_view)); - MOCK_METHOD3(structField, void(const absl::string_view, FieldType, int16_t)); - MOCK_METHOD0(structEnd, void()); - MOCK_METHOD0(messageComplete, void()); + TransportType type_{TransportType::Auto}; }; class MockProtocol : public Protocol { @@ -56,6 +53,7 @@ class MockProtocol : public Protocol { // ThriftProxy::Protocol MOCK_CONST_METHOD0(name, const std::string&()); + MOCK_CONST_METHOD0(type, ProtocolType()); MOCK_METHOD4(readMessageBegin, bool(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id)); MOCK_METHOD1(readMessageEnd, bool(Buffer::Instance& buffer)); @@ -105,8 +103,100 @@ class MockProtocol : public Protocol { MOCK_METHOD2(writeBinary, void(Buffer::Instance& buffer, const std::string& value)); std::string name_{"mock"}; + ProtocolType type_{ProtocolType::Auto}; +}; + +class MockDecoderCallbacks : public DecoderCallbacks { +public: + MockDecoderCallbacks(); + ~MockDecoderCallbacks(); + + // ThriftProxy::DecoderCallbacks + MOCK_METHOD0(newDecoderFilter, ThriftFilters::DecoderFilter&()); +}; + +namespace ThriftFilters { + +class MockDecoderFilter : public DecoderFilter { +public: + MockDecoderFilter(); + ~MockDecoderFilter(); + + // ThriftProxy::ThriftFilters::DecoderFilter + MOCK_METHOD0(onDestroy, void()); + MOCK_METHOD1(setDecoderFilterCallbacks, void(DecoderFilterCallbacks& callbacks)); + MOCK_METHOD0(resetUpstreamConnection, void()); + MOCK_METHOD1(transportBegin, FilterStatus(absl::optional size)); + MOCK_METHOD0(transportEnd, FilterStatus()); + MOCK_METHOD3(messageBegin, + FilterStatus(const absl::string_view name, MessageType msg_type, int32_t seq_id)); + MOCK_METHOD0(messageEnd, FilterStatus()); + MOCK_METHOD1(structBegin, FilterStatus(const absl::string_view name)); + MOCK_METHOD0(structEnd, FilterStatus()); + MOCK_METHOD3(fieldBegin, + FilterStatus(const absl::string_view name, FieldType msg_type, int16_t field_id)); + MOCK_METHOD0(fieldEnd, FilterStatus()); + MOCK_METHOD1(boolValue, FilterStatus(bool value)); + MOCK_METHOD1(byteValue, FilterStatus(uint8_t value)); + MOCK_METHOD1(int16Value, FilterStatus(int16_t value)); + MOCK_METHOD1(int32Value, FilterStatus(int32_t value)); + MOCK_METHOD1(int64Value, FilterStatus(int64_t value)); + MOCK_METHOD1(doubleValue, FilterStatus(double value)); + MOCK_METHOD1(stringValue, FilterStatus(absl::string_view value)); + MOCK_METHOD3(mapBegin, FilterStatus(FieldType key_type, FieldType value_type, uint32_t size)); + MOCK_METHOD0(mapEnd, FilterStatus()); + MOCK_METHOD2(listBegin, FilterStatus(FieldType elem_type, uint32_t size)); + MOCK_METHOD0(listEnd, FilterStatus()); + MOCK_METHOD2(setBegin, FilterStatus(FieldType elem_type, uint32_t size)); + MOCK_METHOD0(setEnd, FilterStatus()); +}; + +class MockDecoderFilterCallbacks : public DecoderFilterCallbacks { +public: + MockDecoderFilterCallbacks(); + ~MockDecoderFilterCallbacks(); + + // ThriftProxy::ThriftFilters::DecoderFilterCallbacks + MOCK_METHOD0(streamId, uint64_t()); + MOCK_METHOD0(connection, const Network::Connection*()); + MOCK_METHOD0(continueDecoding, void()); + MOCK_METHOD0(route, Router::RouteConstSharedPtr()); + MOCK_METHOD0(downstreamTransportType, TransportType()); + MOCK_METHOD0(downstreamProtocolType, ProtocolType()); + void sendLocalReply(DirectResponsePtr&& response) override { sendLocalReply_(response); } + MOCK_METHOD2(startUpstreamResponse, void(TransportType, ProtocolType)); + MOCK_METHOD1(upstreamData, bool(Buffer::Instance&)); + MOCK_METHOD0(resetDownstreamConnection, void()); + + MOCK_METHOD1(sendLocalReply_, void(DirectResponsePtr&)); + + uint64_t stream_id_{1}; + NiceMock connection_; +}; + +} // namespace ThriftFilters + +namespace Router { + +class MockRouteEntry : public RouteEntry { +public: + MockRouteEntry(); + ~MockRouteEntry(); + + // ThriftProxy::Router::RouteEntry + MOCK_CONST_METHOD0(clusterName, const std::string&()); +}; + +class MockRoute : public Route { +public: + MockRoute(); + ~MockRoute(); + + // ThriftProxy::Router::Route + MOCK_CONST_METHOD0(routeEntry, const RouteEntry*()); }; +} // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc index 58420a0492622..7a8fef74a1490 100644 --- a/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc @@ -24,10 +24,16 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +TEST(ProtocolNames, FromType) { + for (int i = 0; i <= static_cast(ProtocolType::LastProtocolType); i++) { + ProtocolType type = static_cast(i); + EXPECT_NE("", ProtocolNames::get().fromType(type)); + } +} + TEST(AutoProtocolTest, NotEnoughData) { Buffer::OwnedImpl buffer; - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = -1; @@ -41,8 +47,7 @@ TEST(AutoProtocolTest, NotEnoughData) { TEST(AutoProtocolTest, UnknownProtocol) { Buffer::OwnedImpl buffer; - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = -1; @@ -59,8 +64,7 @@ TEST(AutoProtocolTest, UnknownProtocol) { TEST(AutoProtocolTest, ReadMessageBegin) { // Binary Protocol { - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = -1; @@ -79,12 +83,12 @@ TEST(AutoProtocolTest, ReadMessageBegin) { EXPECT_EQ(seq_id, 1); EXPECT_EQ(buffer.length(), 0); EXPECT_EQ(proto.name(), "binary(auto)"); + EXPECT_EQ(proto.type(), ProtocolType::Binary); } // Compact protocol { - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = 1; @@ -101,13 +105,13 @@ TEST(AutoProtocolTest, ReadMessageBegin) { EXPECT_EQ(seq_id, 0x0102); EXPECT_EQ(buffer.length(), 0); EXPECT_EQ(proto.name(), "compact(auto)"); + EXPECT_EQ(proto.type(), ProtocolType::Compact); } } TEST(AutoProtocolTest, ReadDelegation) { NiceMock* proto = new NiceMock(); - NiceMock dummy_cb; - AutoProtocolImpl auto_proto(dummy_cb); + AutoProtocolImpl auto_proto; auto_proto.setProtocol(ProtocolPtr{proto}); // readMessageBegin @@ -232,8 +236,7 @@ TEST(AutoProtocolTest, ReadDelegation) { TEST(AutoProtocolTest, WriteDelegation) { NiceMock* proto = new NiceMock(); - NiceMock dummy_cb; - AutoProtocolImpl auto_proto(dummy_cb); + AutoProtocolImpl auto_proto; auto_proto.setProtocol(ProtocolPtr{proto}); // writeMessageBegin @@ -319,11 +322,15 @@ TEST(AutoProtocolTest, WriteDelegation) { } TEST(AutoProtocolTest, Name) { - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; EXPECT_EQ(proto.name(), "auto"); } +TEST(AutoProtocolTest, Type) { + AutoProtocolImpl proto; + EXPECT_EQ(proto.type(), ProtocolType::Auto); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc new file mode 100644 index 0000000000000..2ccae47928318 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -0,0 +1,596 @@ +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.pb.h" +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.pb.validate.h" +#include "envoy/tcp/conn_pool.h" + +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "extensions/filters/network/thrift_proxy/router/config.h" +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/server/mocks.h" +#include "test/mocks/upstream/mocks.h" +#include "test/test_common/printers.h" +#include "test/test_common/registry.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::ContainsRegex; +using testing::Invoke; +using testing::NiceMock; +using testing::Ref; +using testing::Return; +using testing::ReturnRef; +using testing::Test; +using testing::TestWithParam; +using testing::Values; +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +namespace { + +envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration +parseRouteConfigurationFromV2Yaml(const std::string& yaml) { + envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration route_config; + MessageUtil::loadFromYaml(yaml, route_config); + MessageUtil::validate(route_config); + return route_config; +} + +class TestNamedTransportConfigFactory : public NamedTransportConfigFactory { +public: + TestNamedTransportConfigFactory(std::function f) : f_(f) {} + + TransportPtr createTransport() override { return TransportPtr{f_()}; } + std::string name() override { return TransportNames::get().FRAMED; } + + std::function f_; +}; + +class TestNamedProtocolConfigFactory : public NamedProtocolConfigFactory { +public: + TestNamedProtocolConfigFactory(std::function f) : f_(f) {} + + ProtocolPtr createProtocol() override { return ProtocolPtr{f_()}; } + std::string name() override { return ProtocolNames::get().BINARY; } + + std::function f_; +}; + +} // namespace + +class ThriftRouterTestBase { +public: + ThriftRouterTestBase() + : transport_factory_([&]() -> MockTransport* { return transport_; }), + protocol_factory_([&]() -> MockProtocol* { return protocol_; }), + transport_register_(transport_factory_), protocol_register_(protocol_factory_) {} + + void initializeRouter() { + route_ = new NiceMock(); + route_ptr_.reset(route_); + + host_ = new NiceMock(); + host_ptr_.reset(host_); + + router_.reset(new Router(context_.clusterManager())); + + EXPECT_EQ(nullptr, router_->downstreamConnection()); + + router_->setDecoderFilterCallbacks(callbacks_); + } + + void startRequest(MessageType msg_type) { + msg_type_ = msg_type; + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->transportBegin({})); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + + EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, newConnection(_)) + .WillOnce( + Invoke([&](Tcp::ConnectionPool::Callbacks& cb) -> Tcp::ConnectionPool::Cancellable* { + conn_pool_callbacks_ = &cb; + return &handle_; + })); + + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, msg_type_, seq_id_)); + EXPECT_NE(nullptr, conn_pool_callbacks_); + + NiceMock connection; + EXPECT_CALL(callbacks_, connection()).WillRepeatedly(Return(&connection)); + EXPECT_EQ(&connection, router_->downstreamConnection()); + + // Not yet implemented: + EXPECT_EQ(absl::optional(), router_->computeHashKey()); + EXPECT_EQ(nullptr, router_->metadataMatchCriteria()); + EXPECT_EQ(nullptr, router_->downstreamHeaders()); + } + + void connectUpstream() { + EXPECT_CALL(conn_data_, addUpstreamCallbacks(_)) + .WillOnce(Invoke([&](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void { + upstream_callbacks_ = &cb; + })); + + EXPECT_CALL(callbacks_, downstreamTransportType()).WillOnce(Return(TransportType::Framed)); + transport_ = new NiceMock(); + ON_CALL(*transport_, type()).WillByDefault(Return(TransportType::Framed)); + + EXPECT_CALL(callbacks_, downstreamProtocolType()).WillOnce(Return(ProtocolType::Binary)); + protocol_ = new NiceMock(); + ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary)); + EXPECT_CALL(*protocol_, writeMessageBegin(_, method_name_, msg_type_, seq_id_)); + + EXPECT_CALL(callbacks_, continueDecoding()); + conn_pool_callbacks_->onPoolReady(conn_data_, host_ptr_); + EXPECT_NE(nullptr, upstream_callbacks_); + } + + void sendTrivialStruct(FieldType field_type) { + EXPECT_CALL(*protocol_, writeStructBegin(_, "")); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structBegin({})); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, "", field_type, 1)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldBegin({}, field_type, 1)); + + sendTrivialValue(field_type); + + EXPECT_CALL(*protocol_, writeFieldEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldEnd()); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, "", FieldType::Stop, 0)); + EXPECT_CALL(*protocol_, writeStructEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structEnd()); + } + + void sendTrivialValue(FieldType field_type) { + switch (field_type) { + case FieldType::Bool: + EXPECT_CALL(*protocol_, writeBool(_, true)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->boolValue(true)); + break; + case FieldType::Byte: + EXPECT_CALL(*protocol_, writeByte(_, 2)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->byteValue(2)); + break; + case FieldType::I16: + EXPECT_CALL(*protocol_, writeInt16(_, 3)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int16Value(3)); + break; + case FieldType::I32: + EXPECT_CALL(*protocol_, writeInt32(_, 4)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(4)); + break; + case FieldType::I64: + EXPECT_CALL(*protocol_, writeInt64(_, 5)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int64Value(5)); + break; + case FieldType::Double: + EXPECT_CALL(*protocol_, writeDouble(_, 6.0)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->doubleValue(6.0)); + break; + case FieldType::String: + EXPECT_CALL(*protocol_, writeString(_, "seven")); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->stringValue("seven")); + break; + default: + NOT_REACHED; + } + } + + void completeRequest() { + EXPECT_CALL(*protocol_, writeMessageEnd(_)); + EXPECT_CALL(*transport_, encodeFrame(_, _)); + EXPECT_CALL(conn_data_.connection_, write(_, false)); + + if (msg_type_ == MessageType::Oneway) { + EXPECT_CALL(conn_data_, release()); + } + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->messageEnd()); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->transportEnd()); + } + + void returnResponse() { + Buffer::OwnedImpl buffer; + + EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); + upstream_callbacks_->onUpstreamData(buffer, false); + + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(conn_data_, release()); + upstream_callbacks_->onUpstreamData(buffer, false); + } + + void destroyRouter() { + router_->onDestroy(); + router_.reset(); + } + + TestNamedTransportConfigFactory transport_factory_; + TestNamedProtocolConfigFactory protocol_factory_; + Registry::InjectFactory transport_register_; + Registry::InjectFactory protocol_register_; + + NiceMock context_; + NiceMock callbacks_; + NiceMock* transport_{}; + NiceMock* protocol_{}; + NiceMock* route_{}; + NiceMock route_entry_; + NiceMock* host_{}; + + RouteConstSharedPtr route_ptr_; + Upstream::HostDescriptionConstSharedPtr host_ptr_; + + std::unique_ptr router_; + + std::string cluster_name_{"cluster"}; + + std::string method_name_{"method"}; + MessageType msg_type_{MessageType::Call}; + int32_t seq_id_{1}; + + NiceMock handle_; + NiceMock conn_data_; + Tcp::ConnectionPool::Callbacks* conn_pool_callbacks_{}; + Tcp::ConnectionPool::UpstreamCallbacks* upstream_callbacks_{}; +}; + +class ThriftRouterTest : public ThriftRouterTestBase, public Test { +public: + ThriftRouterTest() {} +}; + +class ThriftRouterFieldTypeTest : public ThriftRouterTestBase, public TestWithParam { +public: + ThriftRouterFieldTypeTest() {} +}; + +INSTANTIATE_TEST_CASE_P(PrimitiveFieldTypes, ThriftRouterFieldTypeTest, + Values(FieldType::Bool, FieldType::Byte, FieldType::I16, FieldType::I32, + FieldType::I64, FieldType::Double, FieldType::String), + fieldTypeParamToString); + +class ThriftRouterContainerTest : public ThriftRouterTestBase, public TestWithParam { +public: + ThriftRouterContainerTest() {} +}; + +INSTANTIATE_TEST_CASE_P(ContainerFieldTypes, ThriftRouterContainerTest, + Values(FieldType::Map, FieldType::List, FieldType::Set), + fieldTypeParamToString); + +TEST_F(ThriftRouterTest, PoolConnectionFailure) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + })); + conn_pool_callbacks_->onPoolFailure(Tcp::ConnectionPool::PoolFailureReason::ConnectionFailure, + host_ptr_); +} + +TEST_F(ThriftRouterTest, PoolOverflowFailure) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*too many connections.*")); + })); + conn_pool_callbacks_->onPoolFailure(Tcp::ConnectionPool::PoolFailureReason::Overflow, host_ptr_); +} + +TEST_F(ThriftRouterTest, NoRoute) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(nullptr)); + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + if (app_ex != nullptr) { + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::UnknownMethod, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*no route.*")); + } + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, NoCluster) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + EXPECT_CALL(context_.cluster_manager_, get(cluster_name_)).WillOnce(Return(nullptr)); + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*unknown cluster.*")); + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, ClusterMaintenanceMode) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + EXPECT_CALL(*context_.cluster_manager_.thread_local_cluster_.cluster_.info_, maintenanceMode()) + .WillOnce(Return(true)); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*maintenance mode.*")); + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, NoHealthyHosts) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + EXPECT_CALL(context_.cluster_manager_, tcpConnPoolForCluster(cluster_name_, _, _)) + .WillOnce(Return(nullptr)); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*no healthy upstream.*")); + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, TruncatedResponse) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(FieldType::String); + completeRequest(); + + Buffer::OwnedImpl buffer; + + EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(conn_data_, release()); + EXPECT_CALL(callbacks_, resetDownstreamConnection()); + + upstream_callbacks_->onUpstreamData(buffer, true); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, UpstreamDataTriggersReset) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(FieldType::String); + completeRequest(); + + Buffer::OwnedImpl buffer; + + EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))) + .WillOnce(Invoke([&](Buffer::Instance&) -> bool { + router_->resetUpstreamConnection(); + return true; + })); + EXPECT_CALL(conn_data_.connection_, close(Network::ConnectionCloseType::NoFlush)); + + upstream_callbacks_->onUpstreamData(buffer, true); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, UnexpectedRouterDestroyBeforeUpstreamConnect) { + initializeRouter(); + startRequest(MessageType::Call); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, UnexpectedRouterDestroy) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + EXPECT_CALL(conn_data_.connection_, close(Network::ConnectionCloseType::NoFlush)); + destroyRouter(); +} + +TEST_P(ThriftRouterFieldTypeTest, OneWay) { + FieldType field_type = GetParam(); + + initializeRouter(); + startRequest(MessageType::Oneway); + connectUpstream(); + sendTrivialStruct(field_type); + completeRequest(); + destroyRouter(); +} + +TEST_P(ThriftRouterFieldTypeTest, Call) { + FieldType field_type = GetParam(); + + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(field_type); + completeRequest(); + returnResponse(); + destroyRouter(); +} + +TEST_P(ThriftRouterContainerTest, DecoderFilterCallbacks) { + FieldType field_type = GetParam(); + + initializeRouter(); + + startRequest(MessageType::Oneway); + connectUpstream(); + + EXPECT_CALL(*protocol_, writeStructBegin(_, "")); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structBegin({})); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, "", field_type, 1)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldBegin({}, field_type, 1)); + + switch (field_type) { + case FieldType::Map: + EXPECT_CALL(*protocol_, writeMapBegin(_, FieldType::I32, FieldType::I32, 2)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, + router_->mapBegin(FieldType::I32, FieldType::I32, 2)); + for (int i = 0; i < 2; i++) { + EXPECT_CALL(*protocol_, writeInt32(_, i)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i)); + EXPECT_CALL(*protocol_, writeInt32(_, i + 100)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i + 100)); + } + EXPECT_CALL(*protocol_, writeMapEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->mapEnd()); + break; + case FieldType::List: + EXPECT_CALL(*protocol_, writeListBegin(_, FieldType::I32, 3)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->listBegin(FieldType::I32, 3)); + for (int i = 0; i < 3; i++) { + EXPECT_CALL(*protocol_, writeInt32(_, i)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i)); + } + EXPECT_CALL(*protocol_, writeListEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->listEnd()); + break; + case FieldType::Set: + EXPECT_CALL(*protocol_, writeSetBegin(_, FieldType::I32, 4)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->setBegin(FieldType::I32, 4)); + for (int i = 0; i < 4; i++) { + EXPECT_CALL(*protocol_, writeInt32(_, i)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i)); + } + EXPECT_CALL(*protocol_, writeSetEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->setEnd()); + break; + default: + NOT_REACHED; + } + + EXPECT_CALL(*protocol_, writeFieldEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldEnd()); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, _, FieldType::Stop, 0)); + EXPECT_CALL(*protocol_, writeStructEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structEnd()); + + completeRequest(); + destroyRouter(); +} + +TEST(RouteMatcherTest, Route) { + const std::string yaml = R"EOF( +name: config +routes: + - match: + method: "method1" + route: + cluster: "cluster1" + - match: + method: "method2" + route: + cluster: "cluster2" +)EOF"; + + envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration config = + parseRouteConfigurationFromV2Yaml(yaml); + + RouteMatcher matcher(config); + EXPECT_EQ(nullptr, matcher.route("unknown")); + EXPECT_EQ(nullptr, matcher.route("METHOD1")); + + RouteConstSharedPtr route = matcher.route("method1"); + EXPECT_NE(nullptr, route); + EXPECT_EQ("cluster1", route->routeEntry()->clusterName()); + + RouteConstSharedPtr route2 = matcher.route("method2"); + EXPECT_NE(nullptr, route2); + EXPECT_EQ("cluster2", route2->routeEntry()->clusterName()); +} + +TEST(RouteMatcherTest, RouteMatchAny) { + const std::string yaml = R"EOF( +name: config +routes: + - match: + method: "method1" + route: + cluster: "cluster1" + - match: {} + route: + cluster: "cluster2" +)EOF"; + + envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration config = + parseRouteConfigurationFromV2Yaml(yaml); + + RouteMatcher matcher(config); + RouteConstSharedPtr route = matcher.route("method1"); + EXPECT_NE(nullptr, route); + EXPECT_EQ("cluster1", route->routeEntry()->clusterName()); + + RouteConstSharedPtr route2 = matcher.route("anything"); + EXPECT_NE(nullptr, route2); + EXPECT_EQ("cluster2", route2->routeEntry()->clusterName()); +} + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc index ea6d4e1e1656f..64ab3cce319be 100644 --- a/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc @@ -14,28 +14,35 @@ using testing::NiceMock; using testing::Ref; -using testing::StrictMock; namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +TEST(TransportNames, FromType) { + for (int i = 0; i <= static_cast(TransportType::LastTransportType); i++) { + TransportType type = static_cast(i); + EXPECT_NE("", TransportNames::get().fromType(type)); + } +} + TEST(AutoTransportTest, NotEnoughData) { Buffer::OwnedImpl buffer; - StrictMock cb; - AutoTransportImpl transport(cb); + AutoTransportImpl transport; + absl::optional size = 100; - EXPECT_FALSE(transport.decodeFrameStart(buffer)); + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(100), size); addRepeated(buffer, 7, 0); - EXPECT_FALSE(transport.decodeFrameStart(buffer)); + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(100), size); } TEST(AutoTransportTest, UnknownTransport) { - StrictMock cb; - AutoTransportImpl transport(cb); + AutoTransportImpl transport; // Looks like unframed, but fails protocol check. { @@ -43,8 +50,10 @@ TEST(AutoTransportTest, UnknownTransport) { addInt32(buffer, 0); addInt32(buffer, 0); - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, + absl::optional size = 100; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, "unknown thrift auto transport frame start 00 00 00 00 00 00 00 00"); + EXPECT_EQ(absl::optional(100), size); } // Looks like framed, but fails protocol check. @@ -53,91 +62,95 @@ TEST(AutoTransportTest, UnknownTransport) { addInt32(buffer, 0xFF); addInt32(buffer, 0); - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, + absl::optional size = 100; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, "unknown thrift auto transport frame start 00 00 00 ff 00 00 00 00"); + EXPECT_EQ(absl::optional(100), size); } } TEST(AutoTransportTest, DecodeFrameStart) { - StrictMock cb; - // Framed transport + binary protocol { - AutoTransportImpl transport(cb); + AutoTransportImpl transport; Buffer::OwnedImpl buffer; addInt32(buffer, 0xFF); addInt16(buffer, 0x8001); addInt16(buffer, 0); - EXPECT_CALL(cb, transportFrameStart(absl::optional(255U))); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(255), size); EXPECT_EQ(transport.name(), "framed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Framed); EXPECT_EQ(buffer.length(), 4); } // Framed transport + compact protocol { - AutoTransportImpl transport(cb); + AutoTransportImpl transport; Buffer::OwnedImpl buffer; addInt32(buffer, 0xFFF); addInt16(buffer, 0x8201); addInt16(buffer, 0); - EXPECT_CALL(cb, transportFrameStart(absl::optional(4095U))); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(4095), size); EXPECT_EQ(transport.name(), "framed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Framed); EXPECT_EQ(buffer.length(), 4); } // Unframed transport + binary protocol { - AutoTransportImpl transport(cb); + AutoTransportImpl transport; Buffer::OwnedImpl buffer; addInt16(buffer, 0x8001); addRepeated(buffer, 6, 0); - EXPECT_CALL(cb, transportFrameStart(absl::optional())); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + absl::optional size = 1; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_FALSE(size.has_value()); EXPECT_EQ(transport.name(), "unframed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Unframed); EXPECT_EQ(buffer.length(), 8); } // Unframed transport + compact protocol { - AutoTransportImpl transport(cb); + AutoTransportImpl transport; Buffer::OwnedImpl buffer; addInt16(buffer, 0x8201); addRepeated(buffer, 6, 0); - EXPECT_CALL(cb, transportFrameStart(absl::optional())); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + absl::optional size = 1; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_FALSE(size.has_value()); EXPECT_EQ(transport.name(), "unframed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Unframed); EXPECT_EQ(buffer.length(), 8); } } TEST(AutoTransportTest, DecodeFrameEnd) { - StrictMock cb; - - AutoTransportImpl transport(cb); + AutoTransportImpl transport; Buffer::OwnedImpl buffer; addInt32(buffer, 0xFF); addInt16(buffer, 0x8001); addInt16(buffer, 0); - EXPECT_CALL(cb, transportFrameStart(absl::optional(255U))); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); EXPECT_EQ(buffer.length(), 4); - EXPECT_CALL(cb, transportFrameComplete()); EXPECT_TRUE(transport.decodeFrameEnd(buffer)); } TEST(AutoTransportTest, EncodeFrame) { - StrictMock cb; MockTransport* mock_transport = new NiceMock(); - AutoTransportImpl transport(cb); + AutoTransportImpl transport; transport.setTransport(TransportPtr{mock_transport}); Buffer::OwnedImpl buffer; @@ -148,11 +161,15 @@ TEST(AutoTransportTest, EncodeFrame) { } TEST(AutoTransportTest, Name) { - StrictMock cb; - AutoTransportImpl transport(cb); + AutoTransportImpl transport; EXPECT_EQ(transport.name(), "auto"); } +TEST(AutoTransportTest, Type) { + AutoTransportImpl transport; + EXPECT_EQ(transport.type(), TransportType::Auto); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc index f26a75219ab61..f83119ffaf383 100644 --- a/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc @@ -2,55 +2,49 @@ #include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" -#include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" -#include "gmock/gmock.h" #include "gtest/gtest.h" -using testing::StrictMock; - namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { TEST(UnframedTransportTest, Name) { - StrictMock cb; - UnframedTransportImpl transport(cb); + UnframedTransportImpl transport; EXPECT_EQ(transport.name(), "unframed"); } -TEST(UnframedTransportTest, DecodeFrameStart) { - StrictMock cb; - EXPECT_CALL(cb, transportFrameStart(absl::optional())); +TEST(UnframedTransportTest, Type) { + UnframedTransportImpl transport; + EXPECT_EQ(transport.type(), TransportType::Unframed); +} - UnframedTransportImpl transport(cb); +TEST(UnframedTransportTest, DecodeFrameStart) { + UnframedTransportImpl transport; Buffer::OwnedImpl buffer; addInt32(buffer, 0xDEADBEEF); - EXPECT_EQ(buffer.length(), 4); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + + absl::optional size = 1; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_FALSE(size.has_value()); EXPECT_EQ(buffer.length(), 4); } TEST(UnframedTransportTest, DecodeFrameEnd) { - StrictMock cb; - EXPECT_CALL(cb, transportFrameComplete()); - - UnframedTransportImpl transport(cb); + UnframedTransportImpl transport; Buffer::OwnedImpl buffer; EXPECT_TRUE(transport.decodeFrameEnd(buffer)); } TEST(UnframedTransportTest, EncodeFrame) { - StrictMock cb; - - UnframedTransportImpl transport(cb); + UnframedTransportImpl transport; Buffer::OwnedImpl message; message.add("fake message"); diff --git a/test/mocks/tcp/BUILD b/test/mocks/tcp/BUILD index 0988722b71d7b..8634b86e9c5c8 100644 --- a/test/mocks/tcp/BUILD +++ b/test/mocks/tcp/BUILD @@ -15,6 +15,7 @@ envoy_cc_mock( deps = [ "//include/envoy/buffer:buffer_interface", "//include/envoy/tcp:conn_pool_interface", + "//test/mocks/network:network_mocks", "//test/mocks/upstream:host_mocks", ], ) diff --git a/test/mocks/tcp/mocks.cc b/test/mocks/tcp/mocks.cc index f8586de4f0415..13c92228d07d4 100644 --- a/test/mocks/tcp/mocks.cc +++ b/test/mocks/tcp/mocks.cc @@ -1,7 +1,8 @@ #include "mocks.h" #include "gmock/gmock.h" -#include "gtest/gtest.h" + +using testing::ReturnRef; namespace Envoy { namespace Tcp { @@ -13,6 +14,11 @@ MockCancellable::~MockCancellable() {} MockUpstreamCallbacks::MockUpstreamCallbacks() {} MockUpstreamCallbacks::~MockUpstreamCallbacks() {} +MockConnectionData::MockConnectionData() { + ON_CALL(*this, connection()).WillByDefault(ReturnRef(connection_)); +} +MockConnectionData::~MockConnectionData() {} + MockInstance::MockInstance() {} MockInstance::~MockInstance() {} diff --git a/test/mocks/tcp/mocks.h b/test/mocks/tcp/mocks.h index ed90509a442cc..4f8be71d8c1bd 100644 --- a/test/mocks/tcp/mocks.h +++ b/test/mocks/tcp/mocks.h @@ -3,11 +3,14 @@ #include "envoy/tcp/conn_pool.h" #include "test/mocks/common.h" +#include "test/mocks/network/mocks.h" #include "test/mocks/upstream/host.h" #include "test/test_common/printers.h" #include "gmock/gmock.h" +using testing::NiceMock; + namespace Envoy { namespace Tcp { namespace ConnectionPool { @@ -30,6 +33,19 @@ class MockUpstreamCallbacks : public UpstreamCallbacks { MOCK_METHOD2(onUpstreamData, void(Buffer::Instance& data, bool end_stream)); }; +class MockConnectionData : public ConnectionData { +public: + MockConnectionData(); + ~MockConnectionData(); + + // Tcp::ConnectionPool::ConnectionData + MOCK_METHOD0(connection, Network::ClientConnection&()); + MOCK_METHOD1(addUpstreamCallbacks, void(ConnectionPool::UpstreamCallbacks&)); + MOCK_METHOD0(release, void()); + + NiceMock connection_; +}; + class MockInstance : public Instance { public: MockInstance(); From 3ae3f17e0d3cbd039e9022d4f2acd298d66858cf Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Mon, 16 Jul 2018 09:14:28 -0700 Subject: [PATCH 2/9] fix asan Signed-off-by: Stephan Zuercher --- .../filters/network/thrift_proxy/decoder.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index 815a379edc9fd..0e0bf253ff216 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -219,42 +219,42 @@ DecoderStateMachine::DecoderStatus DecoderStateMachine::handleValue(Buffer::Inst ProtocolState return_state) { switch (elem_type) { case FieldType::Bool: { - bool value; + bool value{}; if (proto_.readBool(buffer, value)) { return DecoderStatus(return_state, filter_.boolValue(value)); } break; } case FieldType::Byte: { - uint8_t value; + uint8_t value{}; if (proto_.readByte(buffer, value)) { return DecoderStatus(return_state, filter_.byteValue(value)); } break; } case FieldType::I16: { - int16_t value; + int16_t value{}; if (proto_.readInt16(buffer, value)) { return DecoderStatus(return_state, filter_.int16Value(value)); } break; } case FieldType::I32: { - int32_t value; + int32_t value{}; if (proto_.readInt32(buffer, value)) { return DecoderStatus(return_state, filter_.int32Value(value)); } break; } case FieldType::I64: { - int64_t value; + int64_t value{}; if (proto_.readInt64(buffer, value)) { return DecoderStatus(return_state, filter_.int64Value(value)); } break; } case FieldType::Double: { - double value; + double value{}; if (proto_.readDouble(buffer, value)) { return DecoderStatus(return_state, filter_.doubleValue(value)); } From 356e103049e3beee819b586a3684f90b221b5dea Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Mon, 16 Jul 2018 09:33:18 -0700 Subject: [PATCH 3/9] convert unreachable exception into release assert Signed-off-by: Stephan Zuercher --- source/extensions/filters/network/thrift_proxy/config.cc | 9 ++------- source/extensions/filters/network/thrift_proxy/config.h | 2 +- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/source/extensions/filters/network/thrift_proxy/config.cc b/source/extensions/filters/network/thrift_proxy/config.cc index aa7201f483ec6..ffab3d37f53c9 100644 --- a/source/extensions/filters/network/thrift_proxy/config.cc +++ b/source/extensions/filters/network/thrift_proxy/config.cc @@ -117,19 +117,14 @@ DecoderPtr ConfigImpl::createDecoder(DecoderCallbacks& callbacks) { TransportPtr ConfigImpl::createTransport() { TransportTypeMap::const_iterator i = transportTypeMap().find(transport_); - if (i == transportTypeMap().end()) { - throw new EnvoyException(fmt::format("unknown transport type: {}", transport_)); - } + RELEASE_ASSERT(i != transportTypeMap().end(), "invalid transport type"); return NamedTransportConfigFactory::getFactory(i->second).createTransport(); } ProtocolPtr ConfigImpl::createProtocol() { ProtocolTypeMap::const_iterator i = protocolTypeMap().find(proto_); - if (i == protocolTypeMap().end()) { - throw new EnvoyException(fmt::format("unknown protocol type: {}", proto_)); - } - + RELEASE_ASSERT(i != protocolTypeMap().end(), "invalid protocol type"); return NamedProtocolConfigFactory::getFactory(i->second).createProtocol(); } diff --git a/source/extensions/filters/network/thrift_proxy/config.h b/source/extensions/filters/network/thrift_proxy/config.h index 265f80f57ab23..e9b9369572b2c 100644 --- a/source/extensions/filters/network/thrift_proxy/config.h +++ b/source/extensions/filters/network/thrift_proxy/config.h @@ -49,7 +49,7 @@ class ConfigImpl : public Config, return route_matcher_->route(method_name); } - // Config: + // Config ThriftFilterStats& stats() override { return stats_; } ThriftFilters::FilterChainFactory& filterFactory() override { return *this; } DecoderPtr createDecoder(DecoderCallbacks& callbacks) override; From add4328b28ed0321ac59454438ad402bf542f01a Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Wed, 18 Jul 2018 10:00:57 -0700 Subject: [PATCH 4/9] match master changes Signed-off-by: Stephan Zuercher --- .../extensions/filters/network/thrift_proxy/conn_manager.h | 6 +++--- source/extensions/filters/network/thrift_proxy/protocol.h | 2 +- .../filters/network/thrift_proxy/protocol_converter.h | 6 +++--- .../filters/network/thrift_proxy/router/router_impl.cc | 2 +- source/extensions/filters/network/thrift_proxy/transport.h | 2 +- test/extensions/filters/network/thrift_proxy/router_test.cc | 4 ++-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h index 82eb20feab176..15313a61ba66c 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.h +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -135,11 +135,11 @@ class ConnectionManager : public Network::ReadFilter, } // ThriftFilters::DecoderFilter - void onDestroy() override { NOT_IMPLEMENTED; } + void onDestroy() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks&) override { - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - void resetUpstreamConnection() override { NOT_IMPLEMENTED; } + void resetUpstreamConnection() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } ThriftFilters::FilterStatus transportBegin(absl::optional size) override { return decoder_filter_->transportBegin(size); } diff --git a/source/extensions/filters/network/thrift_proxy/protocol.h b/source/extensions/filters/network/thrift_proxy/protocol.h index 37218b5586637..02f2808427e08 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol.h +++ b/source/extensions/filters/network/thrift_proxy/protocol.h @@ -56,7 +56,7 @@ class ProtocolNameValues { case ProtocolType::Auto: return AUTO; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } }; diff --git a/source/extensions/filters/network/thrift_proxy/protocol_converter.h b/source/extensions/filters/network/thrift_proxy/protocol_converter.h index 87380eb5f3b89..af7b6dfa2af3d 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_converter.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_converter.h @@ -26,11 +26,11 @@ class ProtocolConverter : public ThriftFilters::DecoderFilter { } // ThiftFilters::DecoderFilter - void onDestroy() override { NOT_IMPLEMENTED; } + void onDestroy() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks&) override { - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - void resetUpstreamConnection() override { NOT_IMPLEMENTED; } + void resetUpstreamConnection() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, int32_t seq_id) override { proto_->writeMessageBegin(*buffer_, std::string(name), msg_type, seq_id); diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc index eafd626af1b6f..e3165c462d2bd 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -269,7 +269,7 @@ void Router::UpstreamRequest::onResetStream(Tcp::ConnectionPool::PoolFailureReas parent_.callbacks_->resetDownstreamConnection(); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } diff --git a/source/extensions/filters/network/thrift_proxy/transport.h b/source/extensions/filters/network/thrift_proxy/transport.h index 776f3a91d8d8b..1bda083e6bc3e 100644 --- a/source/extensions/filters/network/thrift_proxy/transport.h +++ b/source/extensions/filters/network/thrift_proxy/transport.h @@ -50,7 +50,7 @@ class TransportNameValues { case TransportType::Auto: return AUTO; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } }; diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc index 2ccae47928318..0cbd27b204491 100644 --- a/test/extensions/filters/network/thrift_proxy/router_test.cc +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -187,7 +187,7 @@ class ThriftRouterTestBase { EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->stringValue("seven")); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -519,7 +519,7 @@ TEST_P(ThriftRouterContainerTest, DecoderFilterCallbacks) { EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->setEnd()); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } EXPECT_CALL(*protocol_, writeFieldEnd(_)); From 21feb7544d08199e43f06052c67e98ca8fdf759a Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Wed, 18 Jul 2018 10:38:05 -0700 Subject: [PATCH 5/9] fix connection destroy logging; reset upstreams on conn close Signed-off-by: Stephan Zuercher --- .../network/thrift_proxy/v2alpha1/route.proto | 1 - .../network/thrift_proxy/conn_manager.cc | 21 ++++---- .../network/thrift_proxy/conn_manager_test.cc | 49 ++++++++++++++++--- 3 files changed, 52 insertions(+), 19 deletions(-) diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto index 1d78b99647da9..5c9af1c48755a 100644 --- a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto @@ -25,7 +25,6 @@ message Route { // Route request to some upstream cluster. RouteAction route = 2 [(validate.rules).message.required = true, (gogoproto.nullable) = false]; - ; } // [#comment:next free field: 2] diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc index ff066564ce620..7f632af8425c6 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.cc +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -51,6 +51,7 @@ void ConnectionManager::dispatch() { rpcs_.front()->onError(ex.what()); resetAllRpcs(); + read_callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite); } } @@ -67,8 +68,6 @@ void ConnectionManager::resetAllRpcs() { while (!rpcs_.empty()) { rpcs_.front()->onReset(); } - - read_callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite); } void ConnectionManager::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) { @@ -76,16 +75,14 @@ void ConnectionManager::initializeReadFilterCallbacks(Network::ReadFilterCallbac } void ConnectionManager::onEvent(Network::ConnectionEvent event) { - if (rpcs_.empty()) { - return; - } - - if (event == Network::ConnectionEvent::RemoteClose) { - stats_.cx_destroy_local_with_active_rq_.inc(); - } + if (!rpcs_.empty()) { + if (event == Network::ConnectionEvent::RemoteClose) { + stats_.cx_destroy_remote_with_active_rq_.inc(); + } else if (event == Network::ConnectionEvent::LocalClose) { + stats_.cx_destroy_local_with_active_rq_.inc(); + } - if (event == Network::ConnectionEvent::LocalClose) { - stats_.cx_destroy_remote_with_active_rq_.inc(); + resetAllRpcs(); } } @@ -202,7 +199,7 @@ void ConnectionManager::ActiveRpc::createFilterChain() { } void ConnectionManager::ActiveRpc::onReset() { - // TODO(zuercer): e.g., parent_.stats_.named_.downstream_rq_rx_reset_.inc(); + // TODO(zuercher): e.g., parent_.stats_.named_.downstream_rq_rx_reset_.inc(); parent_.doDeferredRpcDestroy(*this); } diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc index f3cf9ff8444c7..48dce17b7d40e 100644 --- a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -413,7 +413,7 @@ TEST_F(ThriftConnectionManagerTest, OnEvent) { EXPECT_EQ(0U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); } - // Close mid-request + // Remote close mid-request { initializeFilter(); addSeq(buffer_, { @@ -424,28 +424,65 @@ TEST_F(ThriftConnectionManagerTest, OnEvent) { }); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - filter_->onEvent(Network::ConnectionEvent::LocalClose); EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + // Local close mid-request + { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes + 0x80, 0x01, 0x00, 0x01, // binary proto, call type + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x0F, // seq id + }); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + filter_->onEvent(Network::ConnectionEvent::LocalClose); + + EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + buffer_.drain(buffer_.length()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); } - // Close before response + // Remote close before response { initializeFilter(); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - filter_->onEvent(Network::ConnectionEvent::LocalClose); EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); buffer_.drain(buffer_.length()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + // Local close before response + { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + filter_->onEvent(Network::ConnectionEvent::LocalClose); + + EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + + buffer_.drain(buffer_.length()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); } } From bacbdedcff34172fcc68d285ebec29cd49a2f4f6 Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Thu, 19 Jul 2018 09:26:09 -0700 Subject: [PATCH 6/9] add const tags to DecoderFilterCallbacks Signed-off-by: Stephan Zuercher --- .../filters/network/thrift_proxy/conn_manager.cc | 2 +- .../filters/network/thrift_proxy/conn_manager.h | 12 ++++++++---- .../filters/network/thrift_proxy/filters/filter.h | 8 ++++---- test/extensions/filters/network/thrift_proxy/mocks.h | 8 ++++---- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc index 7f632af8425c6..c94bbeefcec36 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.cc +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -215,7 +215,7 @@ void ConnectionManager::ActiveRpc::onError(const std::string& what) { // possible to provide a valid response, so don't try. } -const Network::Connection* ConnectionManager::ActiveRpc::connection() { +const Network::Connection* ConnectionManager::ActiveRpc::connection() const { return &parent_.read_callbacks_->connection(); } diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h index 15313a61ba66c..c366a40c0f2a5 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.h +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -195,12 +195,16 @@ class ConnectionManager : public Network::ReadFilter, ThriftFilters::FilterStatus setEnd() override { return decoder_filter_->setEnd(); } // ThriftFilters::DecoderFilterCallbacks - uint64_t streamId() override { return stream_id_; } - const Network::Connection* connection() override; + uint64_t streamId() const override { return stream_id_; } + const Network::Connection* connection() const override; void continueDecoding() override; Router::RouteConstSharedPtr route() override; - TransportType downstreamTransportType() override { return parent_.decoder_->transportType(); } - ProtocolType downstreamProtocolType() override { return parent_.decoder_->protocolType(); } + TransportType downstreamTransportType() const override { + return parent_.decoder_->transportType(); + } + ProtocolType downstreamProtocolType() const override { + return parent_.decoder_->protocolType(); + } void sendLocalReply(ThriftFilters::DirectResponsePtr&& response) override; void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) override; bool upstreamData(Buffer::Instance& buffer) override; diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h index d677eaa0711eb..969ffcadfc46c 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/filter.h +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -44,12 +44,12 @@ class DecoderFilterCallbacks { /** * @return uint64_t the ID of the originating stream for logging purposes. */ - virtual uint64_t streamId() PURE; + virtual uint64_t streamId() const PURE; /** * @return const Network::Connection* the originating connection, or nullptr if there is none. */ - virtual const Network::Connection* connection() PURE; + virtual const Network::Connection* connection() const PURE; /** * Continue iterating through the filter chain with buffered data. This routine can only be @@ -68,12 +68,12 @@ class DecoderFilterCallbacks { /** * @return TransportType the originating transport. */ - virtual TransportType downstreamTransportType() PURE; + virtual TransportType downstreamTransportType() const PURE; /** * @return ProtocolType the originating protocol. */ - virtual ProtocolType downstreamProtocolType() PURE; + virtual ProtocolType downstreamProtocolType() const PURE; /** * Create a locally generated response using the provided response object. diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index 6fae679009316..f932bc808d418 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -157,12 +157,12 @@ class MockDecoderFilterCallbacks : public DecoderFilterCallbacks { ~MockDecoderFilterCallbacks(); // ThriftProxy::ThriftFilters::DecoderFilterCallbacks - MOCK_METHOD0(streamId, uint64_t()); - MOCK_METHOD0(connection, const Network::Connection*()); + MOCK_CONST_METHOD0(streamId, uint64_t()); + MOCK_CONST_METHOD0(connection, const Network::Connection*()); MOCK_METHOD0(continueDecoding, void()); MOCK_METHOD0(route, Router::RouteConstSharedPtr()); - MOCK_METHOD0(downstreamTransportType, TransportType()); - MOCK_METHOD0(downstreamProtocolType, ProtocolType()); + MOCK_CONST_METHOD0(downstreamTransportType, TransportType()); + MOCK_CONST_METHOD0(downstreamProtocolType, ProtocolType()); void sendLocalReply(DirectResponsePtr&& response) override { sendLocalReply_(response); } MOCK_METHOD2(startUpstreamResponse, void(TransportType, ProtocolType)); MOCK_METHOD1(upstreamData, bool(Buffer::Instance&)); From d065987661b1c6913d53da86ee3a77ee1b76fa51 Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Thu, 19 Jul 2018 13:26:36 -0700 Subject: [PATCH 7/9] bump CI Signed-off-by: Stephan Zuercher From 69b02a31c524a141fd06239ab0751490048637af Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Thu, 19 Jul 2018 13:37:18 -0700 Subject: [PATCH 8/9] rm stray ws Signed-off-by: Stephan Zuercher --- source/extensions/extensions_build_config.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/extensions/extensions_build_config.bzl b/source/extensions/extensions_build_config.bzl index 225c897c12df2..3adfeb6836c9e 100644 --- a/source/extensions/extensions_build_config.bzl +++ b/source/extensions/extensions_build_config.bzl @@ -90,7 +90,7 @@ EXTENSIONS = { # "envoy.filters.thrift.router": "//source/extensions/filters/network/thrift_proxy/router:config", - + # # Tracers # From e4c839e64b65122fc1739bacf66d533c1bd79a93 Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Mon, 23 Jul 2018 15:18:25 -0700 Subject: [PATCH 9/9] match merge changes Signed-off-by: Stephan Zuercher --- .../thrift_proxy/router/router_impl.cc | 29 ++++++++++++- .../network/thrift_proxy/router/router_impl.h | 3 ++ .../network/thrift_proxy/router_test.cc | 41 +++++++++++++++++-- 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc index e3165c462d2bd..057d529f13981 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -164,11 +164,33 @@ void Router::onUpstreamData(Buffer::Instance& data, bool end_stream) { if (end_stream) { // Response is incomplete, but no more data is coming. upstream_request_->onResponseComplete(); - upstream_request_->onResetStream(Tcp::ConnectionPool::PoolFailureReason::ConnectionFailure); + upstream_request_->onResetStream( + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); cleanup(); } } +void Router::onEvent(Network::ConnectionEvent event) { + if (!upstream_request_ || upstream_request_->response_complete_) { + // Client closed connection after completing response. + return; + } + + switch (event) { + case Network::ConnectionEvent::RemoteClose: + upstream_request_->onResetStream( + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); + break; + case Network::ConnectionEvent::LocalClose: + upstream_request_->onResetStream( + Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure); + break; + default: + // Connected is consumed by the connection pool. + NOT_REACHED_GCOVR_EXCL_LINE; + } +} + const Network::Connection* Router::downstreamConnection() const { if (callbacks_ != nullptr) { return callbacks_->connection(); @@ -258,7 +280,10 @@ void Router::UpstreamRequest::onResetStream(Tcp::ConnectionPool::PoolFailureReas method_name_, seq_id_, AppExceptionType::InternalError, fmt::format("too many connections to '{}'", upstream_host_->address()->asString()))}); break; - case Tcp::ConnectionPool::PoolFailureReason::ConnectionFailure: + case Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure: + case Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure: + case Tcp::ConnectionPool::PoolFailureReason::Timeout: + // TODO(zuercher): distinguish between these cases where appropriate (particularly timeout) if (!response_started_) { parent_.callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( method_name_, seq_id_, AppExceptionType::InternalError, diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.h b/source/extensions/filters/network/thrift_proxy/router/router_impl.h index 340a2a66cc4c6..aa202734b5595 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.h +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.h @@ -98,6 +98,9 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, // Tcp::ConnectionPool::UpstreamCallbacks void onUpstreamData(Buffer::Instance& data, bool end_stream) override; + void onEvent(Network::ConnectionEvent event) override; + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} private: struct UpstreamRequest : public Tcp::ConnectionPool::Callbacks { diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc index 0cbd27b204491..9b623f37df2d4 100644 --- a/test/extensions/filters/network/thrift_proxy/router_test.cc +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -276,7 +276,7 @@ INSTANTIATE_TEST_CASE_P(ContainerFieldTypes, ThriftRouterContainerTest, Values(FieldType::Map, FieldType::List, FieldType::Set), fieldTypeParamToString); -TEST_F(ThriftRouterTest, PoolConnectionFailure) { +TEST_F(ThriftRouterTest, PoolRemoteConnectionFailure) { initializeRouter(); startRequest(MessageType::Call); @@ -290,8 +290,43 @@ TEST_F(ThriftRouterTest, PoolConnectionFailure) { EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); })); - conn_pool_callbacks_->onPoolFailure(Tcp::ConnectionPool::PoolFailureReason::ConnectionFailure, - host_ptr_); + conn_pool_callbacks_->onPoolFailure( + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure, host_ptr_); +} + +TEST_F(ThriftRouterTest, PoolLocalConnectionFailure) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + })); + conn_pool_callbacks_->onPoolFailure( + Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure, host_ptr_); +} + +TEST_F(ThriftRouterTest, PoolTimeout) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + })); + conn_pool_callbacks_->onPoolFailure(Tcp::ConnectionPool::PoolFailureReason::Timeout, host_ptr_); } TEST_F(ThriftRouterTest, PoolOverflowFailure) {