diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v3/route.proto b/api/envoy/extensions/filters/network/thrift_proxy/v3/route.proto index cf4c06ae1f19e..b79c9bc9619ea 100644 --- a/api/envoy/extensions/filters/network/thrift_proxy/v3/route.proto +++ b/api/envoy/extensions/filters/network/thrift_proxy/v3/route.proto @@ -81,11 +81,33 @@ message RouteMatch { repeated config.route.v3.HeaderMatcher headers = 4; } -// [#next-free-field: 7] +// [#next-free-field: 8] message RouteAction { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.thrift_proxy.v2alpha1.RouteAction"; + // The router is capable of shadowing traffic from one cluster to another. The current + // implementation is "fire and forget," meaning Envoy will not wait for the shadow cluster to + // respond before returning the response from the primary cluster. All normal statistics are + // collected for the shadow cluster making this feature useful for testing. + // + // .. note:: + // + // Shadowing will not be triggered if the primary cluster does not exist. + message RequestMirrorPolicy { + // Specifies the cluster that requests will be mirrored to. The cluster must + // exist in the cluster manager configuration when the route configuration is loaded. + // If it disappears at runtime, the shadow request will silently be ignored. + string cluster = 1 [(validate.rules).string = {min_len: 1}]; + + // If not specified, all requests to the target cluster will be mirrored. + // + // For some fraction N/D, a random number in the range [0,D) is selected. If the + // number is <= the value of the numerator N, or if the key is not present, the default + // value, the request will be mirrored. + config.core.v3.RuntimeFractionalPercent runtime_fraction = 2; + } + oneof cluster_specifier { option (validate.required) = true; @@ -123,6 +145,9 @@ message RouteAction { // Strip the service prefix from the method name, if there's a prefix. For // example, the method call Service:method would end up being just method. bool strip_service_name = 5; + + // Indicates that the route has request mirroring policies. + repeated RequestMirrorPolicy request_mirror_policies = 7; } // Allows for specification of multiple upstream clusters along with weights that indicate the diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v4alpha/route.proto b/api/envoy/extensions/filters/network/thrift_proxy/v4alpha/route.proto index e638e9b8a2be8..48caaadf2b75b 100644 --- a/api/envoy/extensions/filters/network/thrift_proxy/v4alpha/route.proto +++ b/api/envoy/extensions/filters/network/thrift_proxy/v4alpha/route.proto @@ -81,11 +81,36 @@ message RouteMatch { repeated config.route.v4alpha.HeaderMatcher headers = 4; } -// [#next-free-field: 7] +// [#next-free-field: 8] message RouteAction { option (udpa.annotations.versioning).previous_message_type = "envoy.extensions.filters.network.thrift_proxy.v3.RouteAction"; + // The router is capable of shadowing traffic from one cluster to another. The current + // implementation is "fire and forget," meaning Envoy will not wait for the shadow cluster to + // respond before returning the response from the primary cluster. All normal statistics are + // collected for the shadow cluster making this feature useful for testing. + // + // .. note:: + // + // Shadowing will not be triggered if the primary cluster does not exist. + message RequestMirrorPolicy { + option (udpa.annotations.versioning).previous_message_type = + "envoy.extensions.filters.network.thrift_proxy.v3.RouteAction.RequestMirrorPolicy"; + + // Specifies the cluster that requests will be mirrored to. The cluster must + // exist in the cluster manager configuration when the route configuration is loaded. + // If it disappears at runtime, the shadow request will silently be ignored. + string cluster = 1 [(validate.rules).string = {min_len: 1}]; + + // If not specified, all requests to the target cluster will be mirrored. + // + // For some fraction N/D, a random number in the range [0,D) is selected. If the + // number is <= the value of the numerator N, or if the key is not present, the default + // value, the request will be mirrored. + config.core.v4alpha.RuntimeFractionalPercent runtime_fraction = 2; + } + oneof cluster_specifier { option (validate.required) = true; @@ -123,6 +148,9 @@ message RouteAction { // Strip the service prefix from the method name, if there's a prefix. For // example, the method call Service:method would end up being just method. bool strip_service_name = 5; + + // Indicates that the route has request mirroring policies. + repeated RequestMirrorPolicy request_mirror_policies = 7; } // Allows for specification of multiple upstream clusters along with weights that indicate the diff --git a/docs/root/version_history/current.rst b/docs/root/version_history/current.rst index 1f513cae58091..fbd9d10d49418 100644 --- a/docs/root/version_history/current.rst +++ b/docs/root/version_history/current.rst @@ -64,6 +64,7 @@ New Features * jwt_authn: added support for :ref:`Jwt Cache ` and its size can be specified by :ref:`jwt_cache_size `. * listener: new listener metric ``downstream_cx_transport_socket_connect_timeout`` to track transport socket timeouts. * rbac: added :ref:`destination_port_range ` for matching range of destination ports. +* thrift_proxy: added support for :ref:`mirroring requests `. Deprecated ---------- diff --git a/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v3/route.proto b/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v3/route.proto index cf4c06ae1f19e..b79c9bc9619ea 100644 --- a/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v3/route.proto +++ b/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v3/route.proto @@ -81,11 +81,33 @@ message RouteMatch { repeated config.route.v3.HeaderMatcher headers = 4; } -// [#next-free-field: 7] +// [#next-free-field: 8] message RouteAction { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.thrift_proxy.v2alpha1.RouteAction"; + // The router is capable of shadowing traffic from one cluster to another. The current + // implementation is "fire and forget," meaning Envoy will not wait for the shadow cluster to + // respond before returning the response from the primary cluster. All normal statistics are + // collected for the shadow cluster making this feature useful for testing. + // + // .. note:: + // + // Shadowing will not be triggered if the primary cluster does not exist. + message RequestMirrorPolicy { + // Specifies the cluster that requests will be mirrored to. The cluster must + // exist in the cluster manager configuration when the route configuration is loaded. + // If it disappears at runtime, the shadow request will silently be ignored. + string cluster = 1 [(validate.rules).string = {min_len: 1}]; + + // If not specified, all requests to the target cluster will be mirrored. + // + // For some fraction N/D, a random number in the range [0,D) is selected. If the + // number is <= the value of the numerator N, or if the key is not present, the default + // value, the request will be mirrored. + config.core.v3.RuntimeFractionalPercent runtime_fraction = 2; + } + oneof cluster_specifier { option (validate.required) = true; @@ -123,6 +145,9 @@ message RouteAction { // Strip the service prefix from the method name, if there's a prefix. For // example, the method call Service:method would end up being just method. bool strip_service_name = 5; + + // Indicates that the route has request mirroring policies. + repeated RequestMirrorPolicy request_mirror_policies = 7; } // Allows for specification of multiple upstream clusters along with weights that indicate the diff --git a/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v4alpha/route.proto b/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v4alpha/route.proto index e638e9b8a2be8..48caaadf2b75b 100644 --- a/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v4alpha/route.proto +++ b/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v4alpha/route.proto @@ -81,11 +81,36 @@ message RouteMatch { repeated config.route.v4alpha.HeaderMatcher headers = 4; } -// [#next-free-field: 7] +// [#next-free-field: 8] message RouteAction { option (udpa.annotations.versioning).previous_message_type = "envoy.extensions.filters.network.thrift_proxy.v3.RouteAction"; + // The router is capable of shadowing traffic from one cluster to another. The current + // implementation is "fire and forget," meaning Envoy will not wait for the shadow cluster to + // respond before returning the response from the primary cluster. All normal statistics are + // collected for the shadow cluster making this feature useful for testing. + // + // .. note:: + // + // Shadowing will not be triggered if the primary cluster does not exist. + message RequestMirrorPolicy { + option (udpa.annotations.versioning).previous_message_type = + "envoy.extensions.filters.network.thrift_proxy.v3.RouteAction.RequestMirrorPolicy"; + + // Specifies the cluster that requests will be mirrored to. The cluster must + // exist in the cluster manager configuration when the route configuration is loaded. + // If it disappears at runtime, the shadow request will silently be ignored. + string cluster = 1 [(validate.rules).string = {min_len: 1}]; + + // If not specified, all requests to the target cluster will be mirrored. + // + // For some fraction N/D, a random number in the range [0,D) is selected. If the + // number is <= the value of the numerator N, or if the key is not present, the default + // value, the request will be mirrored. + config.core.v4alpha.RuntimeFractionalPercent runtime_fraction = 2; + } + oneof cluster_specifier { option (validate.required) = true; @@ -123,6 +148,9 @@ message RouteAction { // Strip the service prefix from the method name, if there's a prefix. For // example, the method call Service:method would end up being just method. bool strip_service_name = 5; + + // Indicates that the route has request mirroring policies. + repeated RequestMirrorPolicy request_mirror_policies = 7; } // Allows for specification of multiple upstream clusters along with weights that indicate the diff --git a/source/extensions/filters/network/thrift_proxy/metadata.h b/source/extensions/filters/network/thrift_proxy/metadata.h index 560952003f92f..08b91a1c4f040 100644 --- a/source/extensions/filters/network/thrift_proxy/metadata.h +++ b/source/extensions/filters/network/thrift_proxy/metadata.h @@ -30,6 +30,71 @@ class MessageMetadata { public: MessageMetadata() = default; + std::shared_ptr clone() const { + auto copy = std::make_shared(); + + if (hasFrameSize()) { + copy->setFrameSize(frameSize()); + } + + if (hasProtocol()) { + copy->setProtocol(protocol()); + } + + if (hasMethodName()) { + copy->setMethodName(methodName()); + } + + if (hasSequenceId()) { + copy->setSequenceId(sequenceId()); + } + + if (hasMessageType()) { + copy->setMessageType(messageType()); + } + + Http::HeaderMapImpl::copyFrom(copy->headers(), headers()); + copy->mutableSpans().assign(spans().begin(), spans().end()); + + if (hasAppException()) { + copy->setAppException(appExceptionType(), appExceptionMessage()); + } + + copy->setProtocolUpgradeMessage(isProtocolUpgradeMessage()); + + auto trace_id = traceId(); + if (trace_id.has_value()) { + copy->setTraceId(trace_id.value()); + } + + auto trace_id_high = traceIdHigh(); + if (trace_id_high.has_value()) { + copy->setTraceIdHigh(trace_id_high.value()); + } + + auto span_id = spanId(); + if (span_id.has_value()) { + copy->setSpanId(span_id.value()); + } + + auto parent_span_id = parentSpanId(); + if (parent_span_id.has_value()) { + copy->setParentSpanId(parent_span_id.value()); + } + + auto flags_opt = flags(); + if (flags_opt.has_value()) { + copy->setFlags(flags_opt.value()); + } + + auto sampled_opt = sampled(); + if (sampled_opt.has_value()) { + copy->setSampled(sampled_opt.value()); + } + + return copy; + } + bool hasFrameSize() const { return frame_size_.has_value(); } uint32_t frameSize() const { return frame_size_.value(); } void setFrameSize(uint32_t size) { frame_size_ = size; } diff --git a/source/extensions/filters/network/thrift_proxy/router/BUILD b/source/extensions/filters/network/thrift_proxy/router/BUILD index 7809f34af81b2..048ed66baa952 100644 --- a/source/extensions/filters/network/thrift_proxy/router/BUILD +++ b/source/extensions/filters/network/thrift_proxy/router/BUILD @@ -65,6 +65,29 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "shadow_writer_lib", + srcs = ["shadow_writer_impl.cc"], + hdrs = ["shadow_writer_impl.h"], + deps = [ + ":router_interface", + ":upstream_request_lib", + "//envoy/tcp:conn_pool_interface", + "//envoy/upstream:cluster_manager_interface", + "//envoy/upstream:load_balancer_interface", + "//envoy/upstream:thread_local_cluster_interface", + "//source/common/common:linked_object", + "//source/common/upstream:load_balancer_lib", + "//source/extensions/filters/network:well_known_names", + "//source/extensions/filters/network/thrift_proxy:app_exception_lib", + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_converter_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_interface", + "//source/extensions/filters/network/thrift_proxy:thrift_object_interface", + "//source/extensions/filters/network/thrift_proxy:transport_interface", + ], +) + envoy_cc_library( name = "router_lib", srcs = ["router_impl.cc"], @@ -72,6 +95,7 @@ envoy_cc_library( deps = [ ":router_interface", ":router_ratelimit_lib", + ":shadow_writer_lib", ":upstream_request_lib", "//envoy/tcp:conn_pool_interface", "//envoy/upstream:cluster_manager_interface", diff --git a/source/extensions/filters/network/thrift_proxy/router/config.cc b/source/extensions/filters/network/thrift_proxy/router/config.cc index ef94242c89b8d..f749b99dd1558 100644 --- a/source/extensions/filters/network/thrift_proxy/router/config.cc +++ b/source/extensions/filters/network/thrift_proxy/router/config.cc @@ -5,6 +5,7 @@ #include "envoy/registry/registry.h" #include "source/extensions/filters/network/thrift_proxy/router/router_impl.h" +#include "source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h" namespace Envoy { namespace Extensions { @@ -17,9 +18,13 @@ ThriftFilters::FilterFactoryCb RouterFilterConfig::createFilterFactoryFromProtoT const std::string& stat_prefix, Server::Configuration::FactoryContext& context) { UNREFERENCED_PARAMETER(proto_config); - return [&context, stat_prefix](ThriftFilters::FilterChainFactoryCallbacks& callbacks) -> void { - callbacks.addDecoderFilter( - std::make_shared(context.clusterManager(), stat_prefix, context.scope())); + auto shadow_writer = std::make_shared(context.clusterManager(), stat_prefix, + context.scope(), context.dispatcher()); + + return [&context, stat_prefix, + shadow_writer](ThriftFilters::FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addDecoderFilter(std::make_shared( + context.clusterManager(), stat_prefix, context.scope(), context.runtime(), *shadow_writer)); }; } diff --git a/source/extensions/filters/network/thrift_proxy/router/router.h b/source/extensions/filters/network/thrift_proxy/router/router.h index 6d9a3511878a8..dfda11ac34407 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router.h +++ b/source/extensions/filters/network/thrift_proxy/router/router.h @@ -2,6 +2,7 @@ #include #include +#include #include "envoy/buffer/buffer.h" #include "envoy/router/router.h" @@ -22,6 +23,7 @@ namespace ThriftProxy { namespace Router { class RateLimitPolicy; +class RequestMirrorPolicy; /** * RouteEntry is an individual resolved route entry. @@ -55,6 +57,13 @@ class RouteEntry { * @return const Http::LowerCaseString& the header used to determine the cluster. */ virtual const Http::LowerCaseString& clusterHeader() const PURE; + + /** + * @return const std::vector& the mirror policies associated with this route, + * if any. + */ + virtual const std::vector>& + requestMirrorPolicies() const PURE; }; /** @@ -396,6 +405,82 @@ class RequestOwner : public ProtocolConverter, public Logger::Loggable> + submit(const std::string& cluster_name, MessageMetadataSharedPtr metadata, + TransportType original_transport, ProtocolType original_protocol) PURE; +}; + } // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc index 46067ec210f7a..45284dea5400e 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -24,7 +24,8 @@ RouteEntryImplBase::RouteEntryImplBase( config_headers_(Http::HeaderUtility::buildHeaderDataVector(route.match().headers())), rate_limit_policy_(route.route().rate_limits()), strip_service_name_(route.route().strip_service_name()), - cluster_header_(route.route().cluster_header()) { + cluster_header_(route.route().cluster_header()), + mirror_policies_(buildMirrorPolicies(route.route())) { if (route.route().has_metadata_match()) { const auto filter_it = route.route().metadata_match().filter_metadata().find( Envoy::Config::MetadataFilters::get().ENVOY_LB); @@ -47,6 +48,21 @@ RouteEntryImplBase::RouteEntryImplBase( } } +std::vector> RouteEntryImplBase::buildMirrorPolicies( + const envoy::extensions::filters::network::thrift_proxy::v3::RouteAction& route) { + std::vector> policies{}; + + const auto& proto_policies = route.request_mirror_policies(); + policies.reserve(proto_policies.size()); + for (const auto& policy : proto_policies) { + policies.push_back(std::make_shared( + policy.cluster(), policy.runtime_fraction().runtime_key(), + policy.runtime_fraction().default_value())); + } + + return policies; +} + const std::string& RouteEntryImplBase::clusterName() const { return cluster_name_; } const RouteEntry* RouteEntryImplBase::routeEntry() const { return this; } @@ -187,6 +203,10 @@ void Router::onDestroy() { upstream_request_->resetStream(); cleanup(); } + + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().onRouterDestroy(); + } } void Router::setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks& callbacks) { @@ -245,6 +265,23 @@ FilterStatus Router::messageBegin(MessageMetadataSharedPtr metadata) { auto& upstream_req_info = prepare_result.upstream_request_info.value(); passthrough_supported_ = upstream_req_info.passthrough_supported; + + // Prepare connections for shadow routers, if there are mirror policies configured and currently + // enabled. + const auto& policies = route_entry_->requestMirrorPolicies(); + if (!policies.empty()) { + for (const auto& policy : policies) { + if (policy->enabled(runtime_)) { + auto shadow_router = + shadow_writer_.submit(policy->clusterName(), metadata, upstream_req_info.transport, + upstream_req_info.protocol); + if (shadow_router.has_value()) { + shadow_routers_.push_back(shadow_router.value()); + } + } + } + } + upstream_request_ = std::make_unique(*this, *upstream_req_info.conn_pool_data, metadata, upstream_req_info.transport, upstream_req_info.protocol); @@ -253,11 +290,166 @@ FilterStatus Router::messageBegin(MessageMetadataSharedPtr metadata) { FilterStatus Router::messageEnd() { ProtocolConverter::messageEnd(); - request_size_ += upstream_request_->encodeAndWrite(upstream_request_buffer_); + const auto encode_size = upstream_request_->encodeAndWrite(upstream_request_buffer_); + addSize(encode_size); recordUpstreamRequestSize(*cluster_, request_size_); + + // Dispatch shadow requests, if any. + // Note: if connections aren't ready, the write will happen when appropriate. + for (auto& shadow_router : shadow_routers_) { + auto& router = shadow_router.get(); + router.requestOwner().messageEnd(); + } + return FilterStatus::Continue; } +FilterStatus Router::passthroughData(Buffer::Instance& data) { + for (auto& shadow_router : shadow_routers_) { + Buffer::OwnedImpl shadow_data; + shadow_data.add(data); + shadow_router.get().requestOwner().passthroughData(shadow_data); + } + + return ProtocolConverter::passthroughData(data); +} + +FilterStatus Router::structBegin(absl::string_view name) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().structBegin(name); + } + + return ProtocolConverter::structBegin(name); +} + +FilterStatus Router::structEnd() { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().structEnd(); + } + + return ProtocolConverter::structEnd(); +} + +FilterStatus Router::fieldBegin(absl::string_view name, FieldType& field_type, int16_t& field_id) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().fieldBegin(name, field_type, field_id); + } + + return ProtocolConverter::fieldBegin(name, field_type, field_id); +} + +FilterStatus Router::fieldEnd() { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().fieldEnd(); + } + + return ProtocolConverter::fieldEnd(); +} + +FilterStatus Router::boolValue(bool& value) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().boolValue(value); + } + + return ProtocolConverter::boolValue(value); +} + +FilterStatus Router::byteValue(uint8_t& value) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().byteValue(value); + } + + return ProtocolConverter::byteValue(value); +} + +FilterStatus Router::int16Value(int16_t& value) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().int16Value(value); + } + + return ProtocolConverter::int16Value(value); +} + +FilterStatus Router::int32Value(int32_t& value) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().int32Value(value); + } + + return ProtocolConverter::int32Value(value); +} + +FilterStatus Router::int64Value(int64_t& value) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().int64Value(value); + } + + return ProtocolConverter::int64Value(value); +} + +FilterStatus Router::doubleValue(double& value) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().doubleValue(value); + } + + return ProtocolConverter::doubleValue(value); +} + +FilterStatus Router::stringValue(absl::string_view value) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().stringValue(value); + } + + return ProtocolConverter::stringValue(value); +} + +FilterStatus Router::mapBegin(FieldType& key_type, FieldType& value_type, uint32_t& size) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().mapBegin(key_type, value_type, size); + } + + return ProtocolConverter::mapBegin(key_type, value_type, size); +} + +FilterStatus Router::mapEnd() { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().mapEnd(); + } + + return ProtocolConverter::mapEnd(); +} + +FilterStatus Router::listBegin(FieldType& elem_type, uint32_t& size) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().listBegin(elem_type, size); + } + + return ProtocolConverter::listBegin(elem_type, size); +} + +FilterStatus Router::listEnd() { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().listEnd(); + } + + return ProtocolConverter::listEnd(); +} + +FilterStatus Router::setBegin(FieldType& elem_type, uint32_t& size) { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().setBegin(elem_type, size); + } + + return ProtocolConverter::setBegin(elem_type, size); +} + +FilterStatus Router::setEnd() { + for (auto& shadow_router : shadow_routers_) { + shadow_router.get().requestOwner().setEnd(); + } + + return ProtocolConverter::setEnd(); +} + void Router::onUpstreamData(Buffer::Instance& data, bool end_stream) { const bool done = upstream_request_->handleUpstreamData(data, end_stream, *this, *upstream_response_callbacks_); diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.h b/source/extensions/filters/network/thrift_proxy/router/router_impl.h index 25575ecda7761..0b60226f57dce 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.h +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.h @@ -28,6 +28,25 @@ namespace NetworkFilters { namespace ThriftProxy { namespace Router { +class RequestMirrorPolicyImpl : public RequestMirrorPolicy { +public: + RequestMirrorPolicyImpl(const std::string& cluster_name, const std::string& runtime_key, + const envoy::type::v3::FractionalPercent& default_value) + : cluster_name_(cluster_name), runtime_key_(runtime_key), default_value_(default_value) {} + + // Router::RequestMirrorPolicy + const std::string& clusterName() const override { return cluster_name_; } + bool enabled(Runtime::Loader& runtime) const override { + return runtime_key_.empty() ? true + : runtime.snapshot().featureEnabled(runtime_key_, default_value_); + } + +private: + const std::string cluster_name_; + const std::string runtime_key_; + const envoy::type::v3::FractionalPercent default_value_; +}; + class RouteEntryImplBase : public RouteEntry, public Route, public std::enable_shared_from_this { @@ -42,6 +61,9 @@ class RouteEntryImplBase : public RouteEntry, const RateLimitPolicy& rateLimitPolicy() const override { return rate_limit_policy_; } bool stripServiceName() const override { return strip_service_name_; }; const Http::LowerCaseString& clusterHeader() const override { return cluster_header_; } + const std::vector>& requestMirrorPolicies() const override { + return mirror_policies_; + } // Router::Route const RouteEntry* routeEntry() const override; @@ -75,6 +97,10 @@ class RouteEntryImplBase : public RouteEntry, const RateLimitPolicy& rateLimitPolicy() const override { return parent_.rateLimitPolicy(); } bool stripServiceName() const override { return parent_.stripServiceName(); } const Http::LowerCaseString& clusterHeader() const override { return parent_.clusterHeader(); } + const std::vector>& + requestMirrorPolicies() const override { + return parent_.requestMirrorPolicies(); + } // Router::Route const RouteEntry* routeEntry() const override { return this; } @@ -100,6 +126,10 @@ class RouteEntryImplBase : public RouteEntry, const RateLimitPolicy& rateLimitPolicy() const override { return parent_.rateLimitPolicy(); } bool stripServiceName() const override { return parent_.stripServiceName(); } const Http::LowerCaseString& clusterHeader() const override { return parent_.clusterHeader(); } + const std::vector>& + requestMirrorPolicies() const override { + return parent_.requestMirrorPolicies(); + } // Router::Route const RouteEntry* routeEntry() const override { return this; } @@ -109,6 +139,9 @@ class RouteEntryImplBase : public RouteEntry, const std::string cluster_name_; }; + static std::vector> buildMirrorPolicies( + const envoy::extensions::filters::network::thrift_proxy::v3::RouteAction& route); + const std::string cluster_name_; const std::vector config_headers_; std::vector weighted_clusters_; @@ -117,6 +150,7 @@ class RouteEntryImplBase : public RouteEntry, const RateLimitPolicyImpl rate_limit_policy_; const bool strip_service_name_; const Http::LowerCaseString cluster_header_; + const std::vector> mirror_policies_; }; using RouteEntryImplBaseConstSharedPtr = std::shared_ptr; @@ -184,8 +218,9 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, public ThriftFilters::DecoderFilter { public: Router(Upstream::ClusterManager& cluster_manager, const std::string& stat_prefix, - Stats::Scope& scope) - : RequestOwner(cluster_manager, stat_prefix, scope), passthrough_supported_(false) {} + Stats::Scope& scope, Runtime::Loader& runtime, ShadowWriter& shadow_writer) + : RequestOwner(cluster_manager, stat_prefix, scope), passthrough_supported_(false), + runtime_(runtime), shadow_writer_(shadow_writer) {} ~Router() override = default; @@ -213,6 +248,25 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, FilterStatus transportEnd() override; FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override; FilterStatus messageEnd() override; + FilterStatus passthroughData(Buffer::Instance& data) override; + FilterStatus structBegin(absl::string_view name) override; + FilterStatus structEnd() override; + FilterStatus fieldBegin(absl::string_view name, FieldType& field_type, + int16_t& field_id) override; + FilterStatus fieldEnd() override; + FilterStatus boolValue(bool& value) override; + FilterStatus byteValue(uint8_t& value) override; + FilterStatus int16Value(int16_t& value) override; + FilterStatus int32Value(int32_t& value) override; + FilterStatus int64Value(int64_t& value) override; + FilterStatus doubleValue(double& value) override; + FilterStatus stringValue(absl::string_view value) override; + FilterStatus mapBegin(FieldType& key_type, FieldType& value_type, uint32_t& size) override; + FilterStatus mapEnd() override; + FilterStatus listBegin(FieldType& elem_type, uint32_t& size) override; + FilterStatus listEnd() override; + FilterStatus setBegin(FieldType& elem_type, uint32_t& size) override; + FilterStatus setEnd() override; // Upstream::LoadBalancerContext const Network::Connection* downstreamConnection() const override; @@ -239,6 +293,9 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, bool passthrough_supported_ : 1; uint64_t request_size_{}; + Runtime::Loader& runtime_; + ShadowWriter& shadow_writer_; + std::vector> shadow_routers_{}; }; } // namespace Router diff --git a/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.cc b/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.cc new file mode 100644 index 0000000000000..f4f95566f502a --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.cc @@ -0,0 +1,357 @@ +#include "source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h" + +#include + +#include "envoy/upstream/cluster_manager.h" +#include "envoy/upstream/thread_local_cluster.h" + +#include "source/common/common/utility.h" +#include "source/extensions/filters/network/well_known_names.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +absl::optional> +ShadowWriterImpl::submit(const std::string& cluster_name, MessageMetadataSharedPtr metadata, + TransportType original_transport, ProtocolType original_protocol) { + auto shadow_router = std::make_unique(*this, cluster_name, metadata, + original_transport, original_protocol); + const bool created = shadow_router->createUpstreamRequest(); + if (!created) { + return absl::nullopt; + } + + LinkedList::moveIntoList(std::move(shadow_router), active_routers_); + return *active_routers_.front(); +} + +ShadowRouterImpl::ShadowRouterImpl(ShadowWriterImpl& parent, const std::string& cluster_name, + MessageMetadataSharedPtr& metadata, TransportType transport_type, + ProtocolType protocol_type) + : RequestOwner(parent.clusterManager(), parent.statPrefix(), parent.scope()), parent_(parent), + cluster_name_(cluster_name), metadata_(metadata->clone()), transport_type_(transport_type), + protocol_type_(protocol_type), + transport_(NamedTransportConfigFactory::getFactory(transport_type).createTransport()), + protocol_(NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol()) { + response_decoder_ = std::make_unique(*transport_, *protocol_); + upstream_response_callbacks_ = + std::make_unique(*response_decoder_); +} + +Event::Dispatcher& ShadowRouterImpl::dispatcher() { return parent_.dispatcher(); } + +bool ShadowRouterImpl::createUpstreamRequest() { + auto prepare_result = + prepareUpstreamRequest(cluster_name_, metadata_, transport_type_, protocol_type_, this); + if (prepare_result.exception.has_value()) { + return false; + } + + auto& upstream_req_info = prepare_result.upstream_request_info.value(); + + upstream_request_ = + std::make_unique(*this, *upstream_req_info.conn_pool_data, metadata_, + upstream_req_info.transport, upstream_req_info.protocol); + upstream_request_->start(); + return true; +} + +bool ShadowRouterImpl::requestStarted() const { + return upstream_request_->conn_data_ != nullptr && + upstream_request_->upgrade_response_ == nullptr; +} + +FilterStatus ShadowRouterImpl::passthroughData(Buffer::Instance& data) { + if (requestStarted()) { + return ProtocolConverter::passthroughData(data); + } + + auto copied = std::make_shared(data); + auto cb = [copied = std::move(copied), this]() mutable { + ProtocolConverter::passthroughData(*copied); + }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::structBegin(absl::string_view name) { + if (requestStarted()) { + return ProtocolConverter::structBegin(name); + } + + auto cb = [name_str = std::string(name), this]() { + ProtocolConverter::structBegin(absl::string_view(name_str)); + }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::structEnd() { + if (requestStarted()) { + return ProtocolConverter::structEnd(); + } + + auto cb = [this]() { ProtocolConverter::structEnd(); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::fieldBegin(absl::string_view name, FieldType& field_type, + int16_t& field_id) { + if (requestStarted()) { + return ProtocolConverter::fieldBegin(name, field_type, field_id); + } + + auto cb = [name_str = std::string(name), field_type, field_id, this]() mutable { + ProtocolConverter::fieldBegin(absl::string_view(name_str), field_type, field_id); + }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::fieldEnd() { + if (requestStarted()) { + return ProtocolConverter::fieldEnd(); + } + + auto cb = [this]() { ProtocolConverter::fieldEnd(); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::boolValue(bool& value) { + if (requestStarted()) { + return ProtocolConverter::boolValue(value); + } + + auto cb = [value, this]() mutable { ProtocolConverter::boolValue(value); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::byteValue(uint8_t& value) { + if (requestStarted()) { + return ProtocolConverter::byteValue(value); + } + + auto cb = [value, this]() mutable { ProtocolConverter::byteValue(value); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::int16Value(int16_t& value) { + if (requestStarted()) { + return ProtocolConverter::int16Value(value); + } + + auto cb = [value, this]() mutable { ProtocolConverter::int16Value(value); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::int32Value(int32_t& value) { + if (requestStarted()) { + return ProtocolConverter::int32Value(value); + } + + auto cb = [value, this]() mutable { ProtocolConverter::int32Value(value); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::int64Value(int64_t& value) { + if (requestStarted()) { + return ProtocolConverter::int64Value(value); + } + + auto cb = [value, this]() mutable { ProtocolConverter::int64Value(value); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::doubleValue(double& value) { + if (requestStarted()) { + return ProtocolConverter::doubleValue(value); + } + + auto cb = [value, this]() mutable { ProtocolConverter::doubleValue(value); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::stringValue(absl::string_view value) { + if (requestStarted()) { + return ProtocolConverter::stringValue(value); + } + + auto cb = [value_str = std::string(value), this]() { + ProtocolConverter::stringValue(absl::string_view(value_str)); + }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::mapBegin(FieldType& key_type, FieldType& value_type, + uint32_t& size) { + if (requestStarted()) { + return ProtocolConverter::mapBegin(key_type, value_type, size); + } + + auto cb = [key_type, value_type, size, this]() mutable { + ProtocolConverter::mapBegin(key_type, value_type, size); + }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::mapEnd() { + if (requestStarted()) { + return ProtocolConverter::mapEnd(); + } + + auto cb = [this]() { ProtocolConverter::mapEnd(); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::listBegin(FieldType& elem_type, uint32_t& size) { + if (requestStarted()) { + return ProtocolConverter::listBegin(elem_type, size); + } + + auto cb = [elem_type, size, this]() mutable { ProtocolConverter::listBegin(elem_type, size); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::listEnd() { + if (requestStarted()) { + return ProtocolConverter::listEnd(); + } + + auto cb = [this]() { ProtocolConverter::listEnd(); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::setBegin(FieldType& elem_type, uint32_t& size) { + if (requestStarted()) { + return ProtocolConverter::setBegin(elem_type, size); + } + + auto cb = [elem_type, size, this]() mutable { ProtocolConverter::setBegin(elem_type, size); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::setEnd() { + if (requestStarted()) { + return ProtocolConverter::setEnd(); + } + + auto cb = [this]() { ProtocolConverter::setEnd(); }; + pending_callbacks_.push_back(std::move(cb)); + + return FilterStatus::Continue; +} + +FilterStatus ShadowRouterImpl::messageEnd() { + auto cb = [this]() { + ASSERT(upstream_request_->conn_data_ != nullptr); + + ProtocolConverter::messageEnd(); + const auto encode_size = upstream_request_->encodeAndWrite(upstream_request_buffer_); + addSize(encode_size); + recordUpstreamRequestSize(*cluster_, request_size_); + + request_sent_ = true; + + if (metadata_->messageType() == MessageType::Oneway) { + upstream_request_->releaseConnection(false); + } + }; + + if (requestStarted()) { + cb(); + } else { + request_ready_ = true; + pending_callbacks_.push_back(std::move(cb)); + } + + return FilterStatus::Continue; +} + +bool ShadowRouterImpl::requestInProgress() { + const bool connection_open = upstream_request_->conn_data_ != nullptr; + const bool connection_waiting = upstream_request_->conn_pool_handle_ != nullptr; + + // Connection open and message sent. + const bool message_sent = connection_open && request_sent_; + + // Request ready to go and connection ready or almost ready. + const bool message_ready = request_ready_ && (connection_open || connection_waiting); + + return message_sent || message_ready; +} + +void ShadowRouterImpl::onRouterDestroy() { + // Mark the shadow request to be destroyed when the response gets back + // or the upstream connection finally fails. + router_destroyed_ = true; + + if (!requestInProgress()) { + maybeCleanup(); + } +} + +bool ShadowRouterImpl::waitingForConnection() const { + return upstream_request_->conn_pool_handle_ != nullptr; +} + +void ShadowRouterImpl::maybeCleanup() { + if (router_destroyed_) { + upstream_request_.reset(); + if (inserted()) { + removeFromList(parent_.active_routers_); + } + } +} + +void ShadowRouterImpl::onUpstreamData(Buffer::Instance& data, bool end_stream) { + const bool done = + upstream_request_->handleUpstreamData(data, end_stream, *this, *upstream_response_callbacks_); + if (done) { + maybeCleanup(); + } +} + +void ShadowRouterImpl::onEvent(Network::ConnectionEvent event) { + upstream_request_->onEvent(event); + maybeCleanup(); +} + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h b/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h new file mode 100644 index 0000000000000..c2f193b0d5605 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h @@ -0,0 +1,267 @@ +#pragma once + +#include + +#include "envoy/event/dispatcher.h" +#include "envoy/router/router.h" +#include "envoy/stats/scope.h" +#include "envoy/stats/stats_macros.h" +#include "envoy/tcp/conn_pool.h" +#include "envoy/upstream/load_balancer.h" + +#include "source/common/common/linked_object.h" +#include "source/common/common/logger.h" +#include "source/common/upstream/load_balancer_impl.h" +#include "source/extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "source/extensions/filters/network/thrift_proxy/conn_manager.h" +#include "source/extensions/filters/network/thrift_proxy/router/router.h" +#include "source/extensions/filters/network/thrift_proxy/router/upstream_request.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +struct NullResponseDecoder : public DecoderCallbacks, public ProtocolConverter { + NullResponseDecoder(Transport& transport, Protocol& protocol) + : decoder_(std::make_unique(transport, protocol, *this)) { + initProtocolConverter(protocol, response_buffer_); + } + + virtual ThriftFilters::ResponseStatus upstreamData(Buffer::Instance& data) { + upstream_buffer_.move(data); + + bool underflow = false; + try { + underflow = onData(); + } catch (const AppException&) { + return ThriftFilters::ResponseStatus::Reset; + } catch (const EnvoyException&) { + return ThriftFilters::ResponseStatus::Reset; + } + + ASSERT(complete_ || underflow); + return complete_ ? ThriftFilters::ResponseStatus::Complete + : ThriftFilters::ResponseStatus::MoreData; + } + virtual bool onData() { + bool underflow = false; + decoder_->onData(upstream_buffer_, underflow); + return underflow; + } + MessageMetadataSharedPtr& responseMetadata() { return metadata_; } + bool responseSuccess() { return success_.value_or(false); } + + // ProtocolConverter + FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override { + metadata_ = metadata; + first_reply_field_ = + (metadata->hasMessageType() && metadata->messageType() == MessageType::Reply); + return FilterStatus::Continue; + } + FilterStatus messageEnd() override { + if (first_reply_field_) { + success_ = true; + first_reply_field_ = false; + } + return FilterStatus::Continue; + } + FilterStatus fieldBegin(absl::string_view, FieldType&, int16_t& field_id) override { + if (first_reply_field_) { + success_ = (field_id == 0); + first_reply_field_ = false; + } + return FilterStatus::Continue; + } + FilterStatus transportBegin(MessageMetadataSharedPtr metadata) override { + UNREFERENCED_PARAMETER(metadata); + return FilterStatus::Continue; + } + FilterStatus transportEnd() override { + ASSERT(metadata_ != nullptr); + complete_ = true; + return FilterStatus::Continue; + } + + // DecoderCallbacks + DecoderEventHandler& newDecoderEventHandler() override { return *this; } + bool passthroughEnabled() const override { return false; } + + DecoderPtr decoder_; + Buffer::OwnedImpl response_buffer_; + Buffer::OwnedImpl upstream_buffer_; + MessageMetadataSharedPtr metadata_; + absl::optional success_; + bool complete_ : 1; + bool first_reply_field_ : 1; +}; +using NullResponseDecoderPtr = std::unique_ptr; + +// Adapter from NullResponseDecoder to UpstreamResponseCallbacks. +class ShadowUpstreamResponseCallbacksImpl : public UpstreamResponseCallbacks { +public: + ShadowUpstreamResponseCallbacksImpl(NullResponseDecoder& response_decoder) + : response_decoder_(response_decoder) {} + + void startUpstreamResponse(Transport&, Protocol&) override {} + ThriftFilters::ResponseStatus upstreamData(Buffer::Instance& buffer) override { + return response_decoder_.upstreamData(buffer); + } + MessageMetadataSharedPtr responseMetadata() override { + return response_decoder_.responseMetadata(); + } + bool responseSuccess() override { return response_decoder_.responseSuccess(); } + +private: + NullResponseDecoder& response_decoder_; +}; +using ShadowUpstreamResponseCallbacksImplPtr = std::unique_ptr; + +class ShadowWriterImpl; + +class ShadowRouterImpl : public ShadowRouterHandle, + public RequestOwner, + public Tcp::ConnectionPool::UpstreamCallbacks, + public Upstream::LoadBalancerContextBase, + public Event::DeferredDeletable, + public LinkedObject { +public: + ShadowRouterImpl(ShadowWriterImpl& parent, const std::string& cluster_name, + MessageMetadataSharedPtr& metadata, TransportType transport_type, + ProtocolType protocol_type); + ~ShadowRouterImpl() override = default; + + bool createUpstreamRequest(); + void maybeCleanup(); + void resetStream() { + if (upstream_request_ != nullptr) { + upstream_request_->releaseConnection(true); + } + } + + // ShadowRouterHandle + void onRouterDestroy() override; + bool waitingForConnection() const override; + RequestOwner& requestOwner() override { return *this; } + + // RequestOwner + Tcp::ConnectionPool::UpstreamCallbacks& upstreamCallbacks() override { return *this; } + Buffer::OwnedImpl& buffer() override { return upstream_request_buffer_; } + Event::Dispatcher& dispatcher() override; + void addSize(uint64_t size) override { request_size_ += size; } + void continueDecoding() override { + if (pending_callbacks_.empty()) { + return; + } + + for (auto& cb : pending_callbacks_) { + cb(); + } + } + void resetDownstreamConnection() override {} + void sendLocalReply(const ThriftProxy::DirectResponse&, bool) override {} + void recordResponseDuration(uint64_t value, Stats::Histogram::Unit unit) override { + recordClusterResponseDuration(*cluster_, value, unit); + } + + // RequestOwner::ProtocolConverter + FilterStatus transportBegin(MessageMetadataSharedPtr) override { return FilterStatus::Continue; } + FilterStatus transportEnd() override { return FilterStatus::Continue; } + FilterStatus messageEnd() override; + FilterStatus passthroughData(Buffer::Instance& data) override; + FilterStatus structBegin(absl::string_view name) override; + FilterStatus structEnd() override; + FilterStatus fieldBegin(absl::string_view name, FieldType& field_type, + int16_t& field_id) override; + FilterStatus fieldEnd() override; + FilterStatus boolValue(bool& value) override; + FilterStatus byteValue(uint8_t& value) override; + FilterStatus int16Value(int16_t& value) override; + FilterStatus int32Value(int32_t& value) override; + FilterStatus int64Value(int64_t& value) override; + FilterStatus doubleValue(double& value) override; + FilterStatus stringValue(absl::string_view value) override; + FilterStatus mapBegin(FieldType& key_type, FieldType& value_type, uint32_t& size) override; + FilterStatus mapEnd() override; + FilterStatus listBegin(FieldType& elem_type, uint32_t& size) override; + FilterStatus listEnd() override; + FilterStatus setBegin(FieldType& elem_type, uint32_t& size) override; + FilterStatus setEnd() override; + + // Tcp::ConnectionPool::UpstreamCallbacks + void onUpstreamData(Buffer::Instance& data, bool end_stream) override; + void onEvent(Network::ConnectionEvent event) override; + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + + // Upstream::LoadBalancerContextBase + const Network::Connection* downstreamConnection() const override { return nullptr; } + const Envoy::Router::MetadataMatchCriteria* metadataMatchCriteria() override { return nullptr; } + +private: + friend class ShadowWriterTest; + + void writeRequest(); + bool requestInProgress(); + bool requestStarted() const; + + ShadowWriterImpl& parent_; + const std::string cluster_name_; + MessageMetadataSharedPtr metadata_; + const TransportType transport_type_; + const ProtocolType protocol_type_; + TransportPtr transport_; + ProtocolPtr protocol_; + NullResponseDecoderPtr response_decoder_; + ShadowUpstreamResponseCallbacksImplPtr upstream_response_callbacks_; + bool router_destroyed_{}; + bool request_sent_{}; + Buffer::OwnedImpl upstream_request_buffer_; + std::unique_ptr upstream_request_; + uint64_t request_size_{}; + uint64_t response_size_{}; + bool request_ready_ : 1; + + using ConverterCallback = std::function; + std::list pending_callbacks_; +}; + +class ShadowWriterImpl : public ShadowWriter, Logger::Loggable { +public: + ShadowWriterImpl(Upstream::ClusterManager& cm, const std::string& stat_prefix, + Stats::Scope& scope, Event::Dispatcher& dispatcher) + : cm_(cm), stat_prefix_(stat_prefix), scope_(scope), dispatcher_(dispatcher) {} + + ~ShadowWriterImpl() override { + while (!active_routers_.empty()) { + auto& router = active_routers_.front(); + router->resetStream(); + router->onRouterDestroy(); + } + } + + // Router::ShadowWriter + Upstream::ClusterManager& clusterManager() override { return cm_; } + const std::string& statPrefix() const override { return stat_prefix_; } + Stats::Scope& scope() override { return scope_; } + Event::Dispatcher& dispatcher() override { return dispatcher_; } + absl::optional> + submit(const std::string& cluster_name, MessageMetadataSharedPtr metadata, + TransportType original_transport, ProtocolType original_protocol) override; + +private: + friend class ShadowRouterImpl; + + Upstream::ClusterManager& cm_; + const std::string stat_prefix_; + Stats::Scope& scope_; + Event::Dispatcher& dispatcher_; + std::list> active_routers_; +}; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index 6e6cb31a06288..cd3fe562efe15 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -28,6 +28,7 @@ envoy_extension_cc_mock( "//source/extensions/filters/network/thrift_proxy/router:router_ratelimit_interface", "//test/mocks/network:network_mocks", "//test/mocks/stream_info:stream_info_mocks", + "//test/mocks/upstream:upstream_mocks", "//test/test_common:printers_lib", ], ) @@ -270,6 +271,7 @@ envoy_extension_cc_test( "//source/extensions/filters/network/thrift_proxy:config", "//source/extensions/filters/network/thrift_proxy/router:config", "//source/extensions/filters/network/thrift_proxy/router:router_lib", + "//source/extensions/filters/network/thrift_proxy/router:shadow_writer_lib", "//test/mocks/network:network_mocks", "//test/mocks/server:factory_context_mocks", "//test/mocks/upstream:host_mocks", @@ -360,3 +362,22 @@ envoy_extension_cc_test( "@envoy_api//envoy/extensions/filters/network/thrift_proxy/v3:pkg_cc_proto", ], ) + +envoy_extension_cc_test( + name = "shadow_writer_test", + srcs = ["shadow_writer_test.cc"], + extension_names = ["envoy.filters.network.thrift_proxy"], + deps = [ + ":mocks", + ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:app_exception_lib", + "//source/extensions/filters/network/thrift_proxy:config", + "//source/extensions/filters/network/thrift_proxy/router:shadow_writer_lib", + "//test/mocks/network:network_mocks", + "//test/mocks/server:factory_context_mocks", + "//test/mocks/upstream:host_mocks", + "//test/test_common:printers_lib", + "//test/test_common:registry_lib", + "@envoy_api//envoy/extensions/filters/network/thrift_proxy/v3:pkg_cc_proto", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/mocks.cc b/test/extensions/filters/network/thrift_proxy/mocks.cc index 8e93ace6f0c3e..e4a795a6fd0ce 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.cc +++ b/test/extensions/filters/network/thrift_proxy/mocks.cc @@ -130,12 +130,18 @@ MockRouteEntry::MockRouteEntry() { ON_CALL(*this, clusterName()).WillByDefault(ReturnRef(cluster_name_)); ON_CALL(*this, rateLimitPolicy()).WillByDefault(ReturnRef(rate_limit_policy_)); ON_CALL(*this, clusterHeader()).WillByDefault(ReturnRef(cluster_header_)); + ON_CALL(*this, requestMirrorPolicies()).WillByDefault(ReturnRef(policies_)); } MockRouteEntry::~MockRouteEntry() = default; MockRoute::MockRoute() { ON_CALL(*this, routeEntry()).WillByDefault(Return(&route_entry_)); } MockRoute::~MockRoute() = default; +MockShadowWriter::MockShadowWriter() { + ON_CALL(*this, submit(_, _, _, _)).WillByDefault(Return(router_handle_)); +} +MockShadowWriter::~MockShadowWriter() = default; + } // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index b55d4bc466a9c..b3eddda2cb352 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -14,6 +14,7 @@ #include "test/mocks/network/mocks.h" #include "test/mocks/stream_info/mocks.h" +#include "test/mocks/upstream/cluster_manager.h" #include "test/test_common/printers.h" #include "gmock/gmock.h" @@ -333,10 +334,13 @@ class MockRouteEntry : public RouteEntry { MOCK_METHOD(RateLimitPolicy&, rateLimitPolicy, (), (const)); MOCK_METHOD(bool, stripServiceName, (), (const)); MOCK_METHOD(const Http::LowerCaseString&, clusterHeader, (), (const)); + MOCK_METHOD(const std::vector>&, requestMirrorPolicies, (), + (const)); std::string cluster_name_{"fake_cluster"}; Http::LowerCaseString cluster_header_{""}; NiceMock rate_limit_policy_; + std::vector> policies_; }; class MockRoute : public Route { @@ -350,6 +354,21 @@ class MockRoute : public Route { NiceMock route_entry_; }; +class MockShadowWriter : public ShadowWriter { +public: + MockShadowWriter(); + ~MockShadowWriter() override; + + MOCK_METHOD(Upstream::ClusterManager&, clusterManager, (), ()); + MOCK_METHOD(std::string&, statPrefix, (), (const)); + MOCK_METHOD(Stats::Scope&, scope, (), ()); + MOCK_METHOD(Event::Dispatcher&, dispatcher, (), ()); + MOCK_METHOD(absl::optional>, submit, + (const std::string&, MessageMetadataSharedPtr, TransportType, ProtocolType), ()); + + absl::optional> router_handle_{absl::nullopt}; +}; + } // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters diff --git a/test/extensions/filters/network/thrift_proxy/router_ratelimit_test.cc b/test/extensions/filters/network/thrift_proxy/router_ratelimit_test.cc index 15c19bd2be7bb..cb44c2df41349 100644 --- a/test/extensions/filters/network/thrift_proxy/router_ratelimit_test.cc +++ b/test/extensions/filters/network/thrift_proxy/router_ratelimit_test.cc @@ -47,8 +47,8 @@ class ThriftRateLimitConfigurationTest : public testing::Test { return *metadata_; } - std::unique_ptr config_; NiceMock factory_context_; + std::unique_ptr config_; Network::Address::Ipv4Instance default_remote_address_{"10.0.0.1"}; MessageMetadataSharedPtr metadata_; }; diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc index 1b38ade61f648..773b7fbf936fa 100644 --- a/test/extensions/filters/network/thrift_proxy/router_test.cc +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -9,6 +9,7 @@ #include "source/extensions/filters/network/thrift_proxy/config.h" #include "source/extensions/filters/network/thrift_proxy/router/config.h" #include "source/extensions/filters/network/thrift_proxy/router/router_impl.h" +#include "source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h" #include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" @@ -67,43 +68,72 @@ class ThriftRouterTestBase { public: ThriftRouterTestBase() : transport_factory_([&]() -> MockTransport* { - ASSERT(transport_ == nullptr); - transport_ = new NiceMock(); - if (mock_transport_cb_) { - mock_transport_cb_(transport_); + // Create shadow transports. + auto transport = new NiceMock(); + transports_requested_++; + + // Ignore null response decoder transports. + bool is_response_transport = shadow_writer_impl_ != nullptr && + (transports_requested_ == 1 || transports_requested_ == 3); + if (!is_response_transport) { + if (mock_transport_cb_) { + mock_transport_cb_(transport); + } + all_transports_.push_back(transport); + transport_ = transport; } - return transport_; + + return transport; }), protocol_factory_([&]() -> MockProtocol* { - ASSERT(protocol_ == nullptr); - protocol_ = new NiceMock(); - if (mock_protocol_cb_) { - mock_protocol_cb_(protocol_); + // Create shadow protocols. + auto protocol = new NiceMock(); + protocols_requested_++; + + // Ditto for protocols. + bool is_response_protocol = shadow_writer_impl_ != nullptr && + (protocols_requested_ == 1 || protocols_requested_ == 3); + if (!is_response_protocol) { + if (mock_protocol_cb_) { + mock_protocol_cb_(protocol); + } + all_protocols_.push_back(protocol); + protocol_ = protocol; } - return protocol_; + + return protocol; }), transport_register_(transport_factory_), protocol_register_(protocol_factory_) { context_.cluster_manager_.initializeThreadLocalClusters({"cluster"}); } - void initializeRouter() { + void initializeRouter(bool use_real_shadow_writer = false) { route_ = new NiceMock(); route_ptr_.reset(route_); - router_ = std::make_unique(context_.clusterManager(), "test", context_.scope()); + if (!use_real_shadow_writer) { + router_ = std::make_unique(context_.clusterManager(), "test", context_.scope(), + context_.runtime(), shadow_writer_); + } else { + shadow_writer_impl_ = std::make_shared(context_.clusterManager(), "test", + context_.scope(), dispatcher_); + router_ = std::make_unique(context_.clusterManager(), "test", context_.scope(), + context_.runtime(), *shadow_writer_impl_); + } EXPECT_EQ(nullptr, router_->downstreamConnection()); router_->setDecoderFilterCallbacks(callbacks_); } - void initializeMetadata(MessageType msg_type, std::string method = "method") { + void initializeMetadata(MessageType msg_type, std::string method = "method", + int32_t sequence_id = 1) { msg_type_ = msg_type; metadata_ = std::make_shared(); metadata_->setMethodName(method); metadata_->setMessageType(msg_type_); - metadata_->setSequenceId(1); + metadata_->setSequenceId(sequence_id); } void startRequest(MessageType msg_type, std::string method = "method", @@ -170,14 +200,14 @@ class ThriftRouterTestBase { EXPECT_NE(nullptr, upstream_callbacks_); } - void startRequestWithExistingConnection(MessageType msg_type) { + void startRequestWithExistingConnection(MessageType msg_type, int32_t sequence_id = 1) { EXPECT_EQ(FilterStatus::Continue, router_->transportBegin({})); EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); - initializeMetadata(msg_type); + initializeMetadata(msg_type, "method", sequence_id); EXPECT_CALL(*context_.cluster_manager_.thread_local_cluster_.tcp_conn_pool_.connection_data_, addUpstreamCallbacks(_)) @@ -233,20 +263,28 @@ class ThriftRouterTestBase { } void sendTrivialStruct(FieldType field_type) { - EXPECT_CALL(*protocol_, writeStructBegin(_, "")); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeStructBegin(_, "")); + } EXPECT_EQ(FilterStatus::Continue, router_->structBegin({})); int16_t id = 1; - EXPECT_CALL(*protocol_, writeFieldBegin(_, "", field_type, id)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeFieldBegin(_, "", field_type, id)); + } EXPECT_EQ(FilterStatus::Continue, router_->fieldBegin({}, field_type, id)); sendTrivialValue(field_type); - EXPECT_CALL(*protocol_, writeFieldEnd(_)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeFieldEnd(_)); + } EXPECT_EQ(FilterStatus::Continue, router_->fieldEnd()); - EXPECT_CALL(*protocol_, writeFieldBegin(_, "", FieldType::Stop, 0)); - EXPECT_CALL(*protocol_, writeStructEnd(_)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeFieldBegin(_, "", FieldType::Stop, 0)); + EXPECT_CALL(*protocol, writeStructEnd(_)); + } EXPECT_EQ(FilterStatus::Continue, router_->structEnd()); } @@ -254,37 +292,51 @@ class ThriftRouterTestBase { switch (field_type) { case FieldType::Bool: { bool v = true; - EXPECT_CALL(*protocol_, writeBool(_, v)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeBool(_, v)); + } EXPECT_EQ(FilterStatus::Continue, router_->boolValue(v)); } break; case FieldType::Byte: { uint8_t v = 2; - EXPECT_CALL(*protocol_, writeByte(_, v)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeByte(_, v)); + } EXPECT_EQ(FilterStatus::Continue, router_->byteValue(v)); } break; case FieldType::I16: { int16_t v = 3; - EXPECT_CALL(*protocol_, writeInt16(_, v)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeInt16(_, v)); + } EXPECT_EQ(FilterStatus::Continue, router_->int16Value(v)); } break; case FieldType::I32: { int32_t v = 4; - EXPECT_CALL(*protocol_, writeInt32(_, v)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeInt32(_, v)); + } EXPECT_EQ(FilterStatus::Continue, router_->int32Value(v)); } break; case FieldType::I64: { int64_t v = 5; - EXPECT_CALL(*protocol_, writeInt64(_, v)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeInt64(_, v)); + } EXPECT_EQ(FilterStatus::Continue, router_->int64Value(v)); } break; case FieldType::Double: { double v = 6.0; - EXPECT_CALL(*protocol_, writeDouble(_, v)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeDouble(_, v)); + } EXPECT_EQ(FilterStatus::Continue, router_->doubleValue(v)); } break; case FieldType::String: { std::string v = "seven"; - EXPECT_CALL(*protocol_, writeString(_, v)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeString(_, v)); + } EXPECT_EQ(FilterStatus::Continue, router_->stringValue(v)); } break; default: @@ -292,9 +344,87 @@ class ThriftRouterTestBase { } } + void sendTrivialMap() { + FieldType container_type = FieldType::I32; + uint32_t size = 2; + + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeMapBegin(_, container_type, container_type, size)); + } + EXPECT_EQ(FilterStatus::Continue, router_->mapBegin(container_type, container_type, size)); + + for (int i = 0; i < 2; i++) { + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeInt32(_, i)); + } + EXPECT_EQ(FilterStatus::Continue, router_->int32Value(i)); + + int j = i + 100; + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeInt32(_, j)); + } + EXPECT_EQ(FilterStatus::Continue, router_->int32Value(j)); + } + + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeMapEnd(_)); + } + EXPECT_EQ(FilterStatus::Continue, router_->mapEnd()); + } + + void sendTrivialList() { + FieldType container_type = FieldType::I32; + uint32_t size = 3; + + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeListBegin(_, container_type, size)); + } + EXPECT_EQ(FilterStatus::Continue, router_->listBegin(container_type, size)); + + for (int i = 0; i < 3; i++) { + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeInt32(_, i)); + } + EXPECT_EQ(FilterStatus::Continue, router_->int32Value(i)); + } + + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeListEnd(_)); + } + EXPECT_EQ(FilterStatus::Continue, router_->listEnd()); + } + + void sendTrivialSet() { + FieldType container_type = FieldType::I32; + uint32_t size = 4; + + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeSetBegin(_, container_type, size)); + } + EXPECT_EQ(FilterStatus::Continue, router_->setBegin(container_type, size)); + + for (int i = 0; i < 4; i++) { + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeInt32(_, i)); + } + EXPECT_EQ(FilterStatus::Continue, router_->int32Value(i)); + } + + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeSetEnd(_)); + } + EXPECT_EQ(FilterStatus::Continue, router_->setEnd()); + } + void completeRequest() { - EXPECT_CALL(*protocol_, writeMessageEnd(_)); - EXPECT_CALL(*transport_, encodeFrame(_, _, _)); + for (auto& protocol : all_protocols_) { + EXPECT_CALL(*protocol, writeMessageEnd(_)); + } + + for (auto& transport : all_transports_) { + EXPECT_CALL(*transport, encodeFrame(_, _, _)); + } + EXPECT_CALL(upstream_connection_, write(_, false)); if (msg_type_ == MessageType::Oneway) { @@ -341,20 +471,28 @@ class ThriftRouterTestBase { std::function mock_transport_cb_{}; std::function mock_protocol_cb_{}; + NiceMock dispatcher_; NiceMock context_; + + std::unique_ptr router_; + MockShadowWriter shadow_writer_; + std::shared_ptr shadow_writer_impl_; + NiceMock connection_; - NiceMock dispatcher_; NiceMock time_source_; NiceMock callbacks_; NiceMock* transport_{}; NiceMock* protocol_{}; + std::vector*> all_transports_{}; + std::vector*> all_protocols_{}; + int32_t transports_requested_{}; + int32_t protocols_requested_{}; NiceMock* route_{}; NiceMock route_entry_; NiceMock* host_{}; Tcp::ConnectionPool::ConnectionStatePtr conn_state_; RouteConstSharedPtr route_ptr_; - std::unique_ptr router_; std::string cluster_name_{"cluster"}; @@ -1393,6 +1531,70 @@ TEST_F(ThriftRouterTest, RequestResponseSize) { destroyRouter(); } +TEST_F(ThriftRouterTest, ShadowRequests) { + struct ShadowClusterInfo { + NiceMock cluster; + NiceMock connection; + Tcp::ConnectionPool::ConnectionStatePtr conn_state; + }; + using ShadowClusterInfoPtr = std::shared_ptr; + absl::flat_hash_map shadow_clusters; + + shadow_clusters.try_emplace("shadow_cluster_1", std::make_shared()); + shadow_clusters.try_emplace("shadow_cluster_2", std::make_shared()); + + for (auto& [name, shadow_cluster_info] : shadow_clusters) { + auto& shadow_cluster = shadow_cluster_info->cluster; + auto& upstream_connection = shadow_cluster_info->connection; + auto& conn_state = shadow_cluster_info->conn_state; + + ON_CALL(context_.cluster_manager_, getThreadLocalCluster(absl::string_view(name))) + .WillByDefault(Return(&shadow_cluster)); + EXPECT_CALL(shadow_cluster.tcp_conn_pool_, newConnection(_)) + .WillOnce( + Invoke([&](Tcp::ConnectionPool::Callbacks& cb) -> Tcp::ConnectionPool::Cancellable* { + shadow_cluster.tcp_conn_pool_.newConnectionImpl(cb); + shadow_cluster.tcp_conn_pool_.poolReady(upstream_connection); + return nullptr; + })); + EXPECT_CALL(upstream_connection, close(_)); + + EXPECT_CALL(*shadow_cluster.tcp_conn_pool_.connection_data_, connectionState()) + .WillRepeatedly( + Invoke([&]() -> Tcp::ConnectionPool::ConnectionState* { return conn_state.get(); })); + EXPECT_CALL(*shadow_cluster.tcp_conn_pool_.connection_data_, setConnectionState_(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::ConnectionStatePtr& cs) -> void { conn_state.swap(cs); })); + + // Set up policies. + envoy::type::v3::FractionalPercent default_value; + auto policy = std::make_shared(name, "", default_value); + route_entry_.policies_.push_back(policy); + } + + initializeRouter(true); + + // Set sequence id to 0, since that's what the new connections used for shadow requests will use. + startRequestWithExistingConnection(MessageType::Call, 0); + + std::vector field_types = {FieldType::Bool, FieldType::Byte, FieldType::I16, + FieldType::I32, FieldType::I64, FieldType::Double, + FieldType::String}; + for (const auto& field_type : field_types) { + sendTrivialStruct(field_type); + } + + sendTrivialMap(); + sendTrivialList(); + sendTrivialSet(); + + completeRequest(); + returnResponse(); + destroyRouter(); + + shadow_writer_impl_ = nullptr; +} + } // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters diff --git a/test/extensions/filters/network/thrift_proxy/shadow_writer_test.cc b/test/extensions/filters/network/thrift_proxy/shadow_writer_test.cc new file mode 100644 index 0000000000000..fa3bb9ebbf739 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/shadow_writer_test.cc @@ -0,0 +1,467 @@ +#include + +#include "envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.pb.h" +#include "envoy/tcp/conn_pool.h" + +#include "source/common/buffer/buffer_impl.h" +#include "source/extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/server/factory_context.h" +#include "test/mocks/upstream/host.h" +#include "test/test_common/printers.h" +#include "test/test_common/registry.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::_; +using testing::Return; +using testing::ReturnRef; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +struct MockNullResponseDecoder : public NullResponseDecoder { + MockNullResponseDecoder(Transport& transport, Protocol& protocol) + : NullResponseDecoder(transport, protocol) {} + + MOCK_METHOD(ThriftFilters::ResponseStatus, upstreamData, (Buffer::Instance & data), ()); +}; + +class ShadowWriterTest : public testing::Test { +public: + ShadowWriterTest() { + shadow_writer_ = std::make_shared(cm_, "test", context_.scope(), dispatcher_); + metadata_ = std::make_shared(); + metadata_->setMethodName("ping"); + metadata_->setMessageType(MessageType::Call); + metadata_->setSequenceId(1); + + host_ = std::make_shared>(); + } + + void testPoolReady(bool oneway = false) { + NiceMock connection; + + EXPECT_CALL(cm_, getThreadLocalCluster(_)).WillOnce(Return(&cluster_)); + EXPECT_CALL(*cluster_.cluster_.info_, maintenanceMode()).WillOnce(Return(false)); + EXPECT_CALL(cluster_, tcpConnPool(_, _)) + .WillOnce(Return(Upstream::TcpPoolData([]() {}, &conn_pool_))); + EXPECT_CALL(conn_pool_, newConnection(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::Callbacks& callbacks) -> Tcp::ConnectionPool::Cancellable* { + auto data = + std::make_unique>(); + EXPECT_CALL(*data, connectionState()) + .WillRepeatedly(Invoke([&]() -> Tcp::ConnectionPool::ConnectionState* { + return conn_state_.get(); + })); + EXPECT_CALL(*data, setConnectionState_(_)) + .WillOnce(Invoke([&](Tcp::ConnectionPool::ConnectionStatePtr& cs) -> void { + conn_state_.swap(cs); + })); + EXPECT_CALL(*data, connection()).WillRepeatedly(ReturnRef(connection)); + callbacks.onPoolReady(std::move(data), host_); + return nullptr; + })); + + auto router_handle = shadow_writer_->submit("shadow_cluster", metadata_, TransportType::Framed, + ProtocolType::Binary); + EXPECT_NE(absl::nullopt, router_handle); + EXPECT_CALL(connection, write(_, false)); + + auto& request_owner = router_handle.value().get().requestOwner(); + runRequestMethods(request_owner); + + // The following is a no-op, since no callbacks are pending. + request_owner.continueDecoding(); + + if (!oneway) { + EXPECT_CALL(connection, close(_)); + } + + shadow_writer_ = nullptr; + + const std::string counter_name = + oneway ? "thrift.upstream_rq_oneway" : "thrift.upstream_rq_call"; + EXPECT_EQ(1UL, cluster_.cluster_.info_->statsScope().counterFromString(counter_name).value()); + } + + void testOnUpstreamData(MessageType message_type = MessageType::Reply, bool success = true, + bool on_data_throw_app_exception = false, + bool on_data_throw_regular_exception = false, + bool close_before_response = false) { + NiceMock connection; + + EXPECT_CALL(cm_, getThreadLocalCluster(_)).WillOnce(Return(&cluster_)); + EXPECT_CALL(*cluster_.cluster_.info_, maintenanceMode()).WillOnce(Return(false)); + EXPECT_CALL(cluster_, tcpConnPool(_, _)) + .WillOnce(Return(Upstream::TcpPoolData([]() {}, &conn_pool_))); + EXPECT_CALL(conn_pool_, newConnection(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::Callbacks& callbacks) -> Tcp::ConnectionPool::Cancellable* { + auto data = + std::make_unique>(); + EXPECT_CALL(*data, connectionState()) + .WillRepeatedly(Invoke([&]() -> Tcp::ConnectionPool::ConnectionState* { + return conn_state_.get(); + })); + EXPECT_CALL(*data, setConnectionState_(_)) + .WillOnce(Invoke([&](Tcp::ConnectionPool::ConnectionStatePtr& cs) -> void { + conn_state_.swap(cs); + })); + + EXPECT_CALL(*data, connection()).WillRepeatedly(ReturnRef(connection)); + callbacks.onPoolReady(std::move(data), host_); + return nullptr; + })); + + ShadowRouterImpl shadow_router(*shadow_writer_, "shadow_cluster", metadata_, + TransportType::Framed, ProtocolType::Binary); + EXPECT_TRUE(shadow_router.createUpstreamRequest()); + + // Exercise methods are no-ops by design. + shadow_router.resetDownstreamConnection(); + shadow_router.onAboveWriteBufferHighWatermark(); + shadow_router.onBelowWriteBufferLowWatermark(); + shadow_router.downstreamConnection(); + shadow_router.metadataMatchCriteria(); + + EXPECT_CALL(connection, write(_, false)); + shadow_router.messageEnd(); + + if (close_before_response) { + shadow_router.onEvent(Network::ConnectionEvent::LocalClose); + return; + } + + // Prepare response metadata & data processing. + MessageMetadataSharedPtr response_metadata = std::make_shared(); + response_metadata->setMessageType(message_type); + response_metadata->setSequenceId(1); + + auto transport_ptr = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + auto protocol_ptr = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + auto decoder_ptr = std::make_unique(*transport_ptr, *protocol_ptr); + decoder_ptr->messageBegin(response_metadata); + decoder_ptr->success_ = success; + + if (on_data_throw_regular_exception || on_data_throw_app_exception) { + EXPECT_CALL(connection, close(_)); + EXPECT_CALL(*decoder_ptr, upstreamData(_)) + .WillOnce(Return(ThriftFilters::ResponseStatus::Reset)); + } else { + EXPECT_CALL(*decoder_ptr, upstreamData(_)) + .WillOnce(Return(ThriftFilters::ResponseStatus::Complete)); + } + + shadow_router.upstream_response_callbacks_ = + std::make_unique(*decoder_ptr); + + Buffer::OwnedImpl response_buffer; + shadow_router.onUpstreamData(response_buffer, false); + + if (on_data_throw_regular_exception || on_data_throw_app_exception) { + return; + } + + // Check stats. + switch (message_type) { + case MessageType::Reply: + EXPECT_EQ(1UL, cluster_.cluster_.info_->statsScope() + .counterFromString("thrift.upstream_resp_reply") + .value()); + if (success) { + EXPECT_EQ(1UL, cluster_.cluster_.info_->statsScope() + .counterFromString("thrift.upstream_resp_success") + .value()); + } else { + EXPECT_EQ(1UL, cluster_.cluster_.info_->statsScope() + .counterFromString("thrift.upstream_resp_error") + .value()); + } + break; + case MessageType::Exception: + EXPECT_EQ(1UL, cluster_.cluster_.info_->statsScope() + .counterFromString("thrift.upstream_resp_exception") + .value()); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } + + void runRequestMethods(RequestOwner& request_owner) { + Buffer::OwnedImpl passthrough_data; + FieldType field_type; + FieldType key_type; + FieldType value_type; + int16_t field_id = 0; + bool bool_value = false; + uint8_t byte_value = 0; + int16_t int16_value = 0; + int32_t int32_value = 0; + int64_t int64_value = 0; + double double_value = 0.0; + uint32_t container_size = 1; + + EXPECT_EQ(FilterStatus::Continue, request_owner.transportBegin(nullptr)); + EXPECT_EQ(FilterStatus::Continue, request_owner.passthroughData(passthrough_data)); + EXPECT_EQ(FilterStatus::Continue, request_owner.structBegin("")); + EXPECT_EQ(FilterStatus::Continue, request_owner.fieldBegin("", field_type, field_id)); + EXPECT_EQ(FilterStatus::Continue, request_owner.fieldEnd()); + EXPECT_EQ(FilterStatus::Continue, request_owner.structEnd()); + EXPECT_EQ(FilterStatus::Continue, request_owner.boolValue(bool_value)); + EXPECT_EQ(FilterStatus::Continue, request_owner.byteValue(byte_value)); + EXPECT_EQ(FilterStatus::Continue, request_owner.int16Value(int16_value)); + EXPECT_EQ(FilterStatus::Continue, request_owner.int32Value(int32_value)); + EXPECT_EQ(FilterStatus::Continue, request_owner.int64Value(int64_value)); + EXPECT_EQ(FilterStatus::Continue, request_owner.doubleValue(double_value)); + EXPECT_EQ(FilterStatus::Continue, request_owner.stringValue("")); + EXPECT_EQ(FilterStatus::Continue, request_owner.mapBegin(key_type, value_type, container_size)); + EXPECT_EQ(FilterStatus::Continue, request_owner.mapEnd()); + EXPECT_EQ(FilterStatus::Continue, request_owner.listBegin(field_type, container_size)); + EXPECT_EQ(FilterStatus::Continue, request_owner.listEnd()); + EXPECT_EQ(FilterStatus::Continue, request_owner.setBegin(field_type, container_size)); + EXPECT_EQ(FilterStatus::Continue, request_owner.setEnd()); + EXPECT_EQ(FilterStatus::Continue, request_owner.messageEnd()); + EXPECT_EQ(FilterStatus::Continue, request_owner.transportEnd()); + } + + NiceMock cluster_; + Tcp::ConnectionPool::ConnectionStatePtr conn_state_; + NiceMock cm_; + NiceMock context_; + NiceMock dispatcher_; + Envoy::ConnectionPool::MockCancellable cancellable_; + MessageMetadataSharedPtr metadata_; + NiceMock conn_pool_; + std::shared_ptr> host_; + std::shared_ptr shadow_writer_; +}; + +TEST_F(ShadowWriterTest, SubmitClusterNotFound) { + EXPECT_CALL(cm_, getThreadLocalCluster(_)).WillOnce(Return(nullptr)); + auto router_handle = shadow_writer_->submit("shadow_cluster", metadata_, TransportType::Framed, + ProtocolType::Binary); + EXPECT_EQ(absl::nullopt, router_handle); +} + +TEST_F(ShadowWriterTest, SubmitClusterInMaintenance) { + std::shared_ptr cluster = + std::make_shared>(); + EXPECT_CALL(*cluster->cluster_.info_, maintenanceMode()).WillOnce(Return(true)); + EXPECT_CALL(cm_, getThreadLocalCluster(_)).WillOnce(Return(cluster.get())); + auto router_handle = shadow_writer_->submit("shadow_cluster", metadata_, TransportType::Framed, + ProtocolType::Binary); + EXPECT_EQ(absl::nullopt, router_handle); +} + +TEST_F(ShadowWriterTest, SubmitNoHealthyUpstream) { + metadata_->setMessageType(MessageType::Oneway); + + std::shared_ptr cluster = + std::make_shared>(); + EXPECT_CALL(cm_, getThreadLocalCluster(_)).WillOnce(Return(cluster.get())); + EXPECT_CALL(*cluster->cluster_.info_, maintenanceMode()).WillOnce(Return(false)); + EXPECT_CALL(*cluster, tcpConnPool(_, _)).WillOnce(Return(absl::nullopt)); + auto router_handle = shadow_writer_->submit("shadow_cluster", metadata_, TransportType::Framed, + ProtocolType::Binary); + EXPECT_EQ(absl::nullopt, router_handle); + + // We still count the request, even if it didn't go through. + EXPECT_EQ( + 1UL, + cluster->cluster_.info_->statsScope().counterFromString("thrift.upstream_rq_oneway").value()); +} + +TEST_F(ShadowWriterTest, SubmitConnectionNotReady) { + EXPECT_CALL(cm_, getThreadLocalCluster(_)).WillOnce(Return(&cluster_)); + EXPECT_CALL(*cluster_.cluster_.info_, maintenanceMode()).WillOnce(Return(false)); + EXPECT_CALL(cluster_, tcpConnPool(_, _)) + .WillOnce(Return(Upstream::TcpPoolData([]() {}, &conn_pool_))); + EXPECT_CALL(cancellable_, cancel(_)); + EXPECT_CALL(conn_pool_, newConnection(_)) + .WillOnce(Invoke([&](Tcp::ConnectionPool::Callbacks&) -> Tcp::ConnectionPool::Cancellable* { + return &cancellable_; + })); + auto router_handle = shadow_writer_->submit("shadow_cluster", metadata_, TransportType::Framed, + ProtocolType::Binary); + EXPECT_NE(absl::nullopt, router_handle); + EXPECT_TRUE(router_handle.value().get().waitingForConnection()); + + EXPECT_EQ( + 1UL, + cluster_.cluster_.info_->statsScope().counterFromString("thrift.upstream_rq_call").value()); +} + +TEST_F(ShadowWriterTest, ShadowRequestPoolReady) { testPoolReady(); } + +TEST_F(ShadowWriterTest, ShadowRequestPoolReadyOneWay) { + metadata_->setMessageType(MessageType::Oneway); + testPoolReady(true); +} + +TEST_F(ShadowWriterTest, ShadowRequestWriteBeforePoolReady) { + Tcp::ConnectionPool::Callbacks* callbacks; + + EXPECT_CALL(cm_, getThreadLocalCluster(_)).WillOnce(Return(&cluster_)); + EXPECT_CALL(*cluster_.cluster_.info_, maintenanceMode()).WillOnce(Return(false)); + EXPECT_CALL(cluster_, tcpConnPool(_, _)) + .WillOnce(Return(Upstream::TcpPoolData([]() {}, &conn_pool_))); + EXPECT_CALL(conn_pool_, newConnection(_)) + .WillOnce( + Invoke([&](Tcp::ConnectionPool::Callbacks& cb) -> Tcp::ConnectionPool::Cancellable* { + callbacks = &cb; + return &cancellable_; + })); + + auto router_handle = shadow_writer_->submit("shadow_cluster", metadata_, TransportType::Framed, + ProtocolType::Binary); + EXPECT_NE(absl::nullopt, router_handle); + + // Write before connection is ready. + auto& request_owner = router_handle.value().get().requestOwner(); + runRequestMethods(request_owner); + + NiceMock connection; + auto data = std::make_unique>(); + EXPECT_CALL(*data, connection()).WillRepeatedly(ReturnRef(connection)); + EXPECT_CALL(*data, connectionState()) + .WillRepeatedly( + Invoke([&]() -> Tcp::ConnectionPool::ConnectionState* { return conn_state_.get(); })); + EXPECT_CALL(*data, setConnectionState_(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::ConnectionStatePtr& cs) -> void { conn_state_.swap(cs); })); + + EXPECT_CALL(connection, write(_, false)); + callbacks->onPoolReady(std::move(data), host_); + + EXPECT_CALL(connection, close(_)); + shadow_writer_ = nullptr; + + EXPECT_EQ( + 1UL, + cluster_.cluster_.info_->statsScope().counterFromString("thrift.upstream_rq_call").value()); +} + +TEST_F(ShadowWriterTest, ShadowRequestPoolFailure) { + EXPECT_CALL(cm_, getThreadLocalCluster(_)).WillOnce(Return(&cluster_)); + EXPECT_CALL(*cluster_.cluster_.info_, maintenanceMode()).WillOnce(Return(false)); + EXPECT_CALL(cluster_, tcpConnPool(_, _)) + .WillOnce(Return(Upstream::TcpPoolData([]() {}, &conn_pool_))); + EXPECT_CALL(conn_pool_, newConnection(_)) + .WillOnce(Invoke([&](Tcp::ConnectionPool::Callbacks& callbacks) + -> Tcp::ConnectionPool::Cancellable* { + auto data = std::make_unique>(); + EXPECT_CALL(*data, connection()).Times(0); + callbacks.onPoolFailure(ConnectionPool::PoolFailureReason::Overflow, "failure", nullptr); + return nullptr; + })); + + auto router_handle = shadow_writer_->submit("shadow_cluster", metadata_, TransportType::Framed, + ProtocolType::Binary); + EXPECT_NE(absl::nullopt, router_handle); + router_handle.value().get().requestOwner().messageEnd(); +} + +TEST_F(ShadowWriterTest, ShadowRequestOnUpstreamDataReplySuccess) { + testOnUpstreamData(MessageType::Reply, true); +} + +TEST_F(ShadowWriterTest, ShadowRequestOnUpstreamDataReplyError) { + testOnUpstreamData(MessageType::Reply, false); +} + +TEST_F(ShadowWriterTest, ShadowRequestOnUpstreamDataReplyException) { + testOnUpstreamData(MessageType::Reply, false); +} + +TEST_F(ShadowWriterTest, ShadowRequestOnUpstreamDataAppException) { + testOnUpstreamData(MessageType::Reply, false, true, false); +} + +TEST_F(ShadowWriterTest, ShadowRequestOnUpstreamDataRegularException) { + testOnUpstreamData(MessageType::Reply, false, false, true); +} + +TEST_F(ShadowWriterTest, ShadowRequestOnUpstreamRemoteClose) { + testOnUpstreamData(MessageType::Reply, false, false, false, true); +} + +TEST_F(ShadowWriterTest, TestNullResponseDecoder) { + auto transport_ptr = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + auto protocol_ptr = NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + auto decoder_ptr = std::make_unique(*transport_ptr, *protocol_ptr); + + decoder_ptr->newDecoderEventHandler(); + EXPECT_FALSE(decoder_ptr->passthroughEnabled()); + + metadata_->setMessageType(MessageType::Reply); + EXPECT_EQ(FilterStatus::Continue, decoder_ptr->messageBegin(metadata_)); + + Buffer::OwnedImpl buffer; + decoder_ptr->upstreamData(buffer); + + EXPECT_EQ(FilterStatus::Continue, decoder_ptr->messageEnd()); + + // First reply field. + { + FieldType field_type; + int16_t field_id = 0; + EXPECT_EQ(FilterStatus::Continue, decoder_ptr->messageBegin(metadata_)); + EXPECT_EQ(FilterStatus::Continue, decoder_ptr->fieldBegin("", field_type, field_id)); + EXPECT_TRUE(decoder_ptr->responseSuccess()); + } + + EXPECT_EQ(FilterStatus::Continue, decoder_ptr->transportBegin(nullptr)); + EXPECT_EQ(FilterStatus::Continue, decoder_ptr->transportEnd()); +} + +struct MockOnDataNullResponseDecoder : public NullResponseDecoder { + MockOnDataNullResponseDecoder(Transport& transport, Protocol& protocol) + : NullResponseDecoder(transport, protocol) {} + + MOCK_METHOD(bool, onData, (), ()); +}; + +TEST_F(ShadowWriterTest, NullResponseDecoderExceptionHandling) { + auto transport_ptr = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + auto protocol_ptr = NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + auto decoder_ptr = std::make_unique(*transport_ptr, *protocol_ptr); + + { + EXPECT_CALL(*decoder_ptr, onData()).WillOnce(Invoke([&]() -> bool { + throw EnvoyException("exception"); + })); + + Buffer::OwnedImpl buffer; + EXPECT_EQ(ThriftFilters::ResponseStatus::Reset, decoder_ptr->upstreamData(buffer)); + } + + { + EXPECT_CALL(*decoder_ptr, onData()).WillOnce(Invoke([&]() -> bool { + throw AppException(AppExceptionType::InternalError, "exception"); + })); + + Buffer::OwnedImpl buffer; + EXPECT_EQ(ThriftFilters::ResponseStatus::Reset, decoder_ptr->upstreamData(buffer)); + } +} + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy