diff --git a/api/envoy/api/v2/route/route.proto b/api/envoy/api/v2/route/route.proto index e260cff69a963..a1f58891c8788 100644 --- a/api/envoy/api/v2/route/route.proto +++ b/api/envoy/api/v2/route/route.proto @@ -585,6 +585,27 @@ message RouteAction { // Connection properties hash policy. ConnectionProperties connection_properties = 3; } + + // The flag that shortcircuits the hash computing. This field provides a + // 'fallback' style of configuration: "if a terminal policy doesn't work, + // fallback to rest of the policy list", it saves time when the terminal + // policy works. + // + // If true, and there is already a hash computed, ignore rest of the + // list of hash polices. + // For example, if the following hash methods are configured: + // + // ========= ======== + // specifier terminal + // ========= ======== + // Header A true + // Header B false + // Header C false + // ========= ======== + // + // The generateHash process ends if policy "header A" generates a hash, as + // it's a terminal policy. + bool terminal = 4; } // Specifies a list of hash policies to use for ring hash load balancing. Each @@ -596,7 +617,9 @@ message RouteAction { // hash policies fail to generate a hash, no hash will be produced for // the route. In this case, the behavior is the same as if no hash policies // were specified (i.e. the ring hash load balancer will choose a random - // backend). + // backend). If a hash policy has the "terminal" attribute set to true, and + // there is already a hash generated, the hash is returned immediately, + // ignoring the rest of the hash policy list. repeated HashPolicy hash_policy = 15; // Indicates that a HTTP/1.1 client connection to this particular route is allowed to diff --git a/include/envoy/secret/secret_manager.h b/include/envoy/secret/secret_manager.h index ca02d639643ab..93205243f046e 100644 --- a/include/envoy/secret/secret_manager.h +++ b/include/envoy/secret/secret_manager.h @@ -75,6 +75,22 @@ class SecretManager { virtual TlsCertificateConfigProviderSharedPtr findOrCreateTlsCertificateProvider( const envoy::api::v2::core::ConfigSource& config_source, const std::string& config_name, Server::Configuration::TransportSocketFactoryContext& secret_provider_context) PURE; + + /** + * Finds and returns a dynamic secret provider associated to SDS config. Create + * a new one if such provider does not exist. + * + * @param config_source a protobuf message object containing a SDS config source. + * @param config_name a name that uniquely refers to the SDS config source. + * @param secret_provider_context context that provides components for creating and initializing + * secret provider. + * @return CertificateValidationContextConfigProviderSharedPtr the dynamic certificate validation + * context secret provider. + */ + virtual CertificateValidationContextConfigProviderSharedPtr + findOrCreateCertificateValidationContextProvider( + const envoy::api::v2::core::ConfigSource& config_source, const std::string& config_name, + Server::Configuration::TransportSocketFactoryContext& secret_provider_context) PURE; }; } // namespace Secret diff --git a/source/common/grpc/google_async_client_impl.cc b/source/common/grpc/google_async_client_impl.cc index 7fdc57768d6c5..7a929a8f24f92 100644 --- a/source/common/grpc/google_async_client_impl.cc +++ b/source/common/grpc/google_async_client_impl.cc @@ -205,7 +205,8 @@ void GoogleAsyncStreamImpl::resetStream() { } void GoogleAsyncStreamImpl::writeQueued() { - if (!call_initialized_ || finish_pending_ || write_pending_ || write_pending_queue_.empty()) { + if (!call_initialized_ || finish_pending_ || write_pending_ || write_pending_queue_.empty() || + draining_cq_) { return; } write_pending_ = true; diff --git a/source/common/router/config_impl.cc b/source/common/router/config_impl.cc index d5301073fd1ac..a48bf4e3baef0 100644 --- a/source/common/router/config_impl.cc +++ b/source/common/router/config_impl.cc @@ -76,9 +76,20 @@ ShadowPolicyImpl::ShadowPolicyImpl(const envoy::api::v2::route::RouteAction& con runtime_key_ = config.request_mirror_policy().runtime_key(); } -class HeaderHashMethod : public HashPolicyImpl::HashMethod { +class HashMethodImplBase : public HashPolicyImpl::HashMethod { public: - HeaderHashMethod(const std::string& header_name) : header_name_(header_name) {} + HashMethodImplBase(bool terminal) : terminal_(terminal) {} + + bool terminal() const override { return terminal_; } + +private: + const bool terminal_; +}; + +class HeaderHashMethod : public HashMethodImplBase { +public: + HeaderHashMethod(const std::string& header_name, bool terminal) + : HashMethodImplBase(terminal), header_name_(header_name) {} absl::optional evaluate(const Network::Address::Instance*, const Http::HeaderMap& headers, @@ -96,18 +107,17 @@ class HeaderHashMethod : public HashPolicyImpl::HashMethod { const Http::LowerCaseString header_name_; }; -class CookieHashMethod : public HashPolicyImpl::HashMethod { +class CookieHashMethod : public HashMethodImplBase { public: CookieHashMethod(const std::string& key, const std::string& path, - const absl::optional& ttl) - : key_(key), path_(path), ttl_(ttl) {} + const absl::optional& ttl, bool terminal) + : HashMethodImplBase(terminal), key_(key), path_(path), ttl_(ttl) {} absl::optional evaluate(const Network::Address::Instance*, const Http::HeaderMap& headers, const HashPolicy::AddCookieCallback add_cookie) const override { absl::optional hash; std::string value = Http::Utility::parseCookieValue(headers, key_); - if (value.empty() && ttl_.has_value()) { value = add_cookie(key_, path_, ttl_.value()); hash = HashUtil::xxHash64(value); @@ -124,8 +134,10 @@ class CookieHashMethod : public HashPolicyImpl::HashMethod { const absl::optional ttl_; }; -class IpHashMethod : public HashPolicyImpl::HashMethod { +class IpHashMethod : public HashMethodImplBase { public: + IpHashMethod(bool terminal) : HashMethodImplBase(terminal) {} + absl::optional evaluate(const Network::Address::Instance* downstream_addr, const Http::HeaderMap&, const HashPolicy::AddCookieCallback) const override { @@ -153,20 +165,21 @@ HashPolicyImpl::HashPolicyImpl( for (auto& hash_policy : hash_policies) { switch (hash_policy.policy_specifier_case()) { case envoy::api::v2::route::RouteAction::HashPolicy::kHeader: - hash_impls_.emplace_back(new HeaderHashMethod(hash_policy.header().header_name())); + hash_impls_.emplace_back( + new HeaderHashMethod(hash_policy.header().header_name(), hash_policy.terminal())); break; case envoy::api::v2::route::RouteAction::HashPolicy::kCookie: { absl::optional ttl; if (hash_policy.cookie().has_ttl()) { ttl = std::chrono::seconds(hash_policy.cookie().ttl().seconds()); } - hash_impls_.emplace_back( - new CookieHashMethod(hash_policy.cookie().name(), hash_policy.cookie().path(), ttl)); + hash_impls_.emplace_back(new CookieHashMethod( + hash_policy.cookie().name(), hash_policy.cookie().path(), ttl, hash_policy.terminal())); break; } case envoy::api::v2::route::RouteAction::HashPolicy::kConnectionProperties: if (hash_policy.connection_properties().source_ip()) { - hash_impls_.emplace_back(new IpHashMethod()); + hash_impls_.emplace_back(new IpHashMethod(hash_policy.terminal())); } break; default: @@ -190,6 +203,11 @@ HashPolicyImpl::generateHash(const Network::Address::Instance* downstream_addr, const uint64_t old_value = hash ? ((hash.value() << 1) | (hash.value() >> 63)) : 0; hash = old_value ^ new_hash.value(); } + // If the policy is a terminal policy and a hash has been generated, ignore + // the rest of the hash policies. + if (hash_impl->terminal() && hash) { + break; + } } return hash; } @@ -851,7 +869,7 @@ RouteConstSharedPtr VirtualHostImpl::getRouteFromEntries(const Http::HeaderMap& const VirtualHostImpl* RouteMatcher::findVirtualHost(const Http::HeaderMap& headers) const { // Fast path the case where we only have a default virtual host. - if (virtual_hosts_.empty() && default_virtual_host_) { + if (virtual_hosts_.empty() && wildcard_virtual_host_suffixes_.empty() && default_virtual_host_) { return default_virtual_host_.get(); } diff --git a/source/common/router/config_impl.h b/source/common/router/config_impl.h index 238f54b99d51b..3aa34f0a2b925 100644 --- a/source/common/router/config_impl.h +++ b/source/common/router/config_impl.h @@ -234,6 +234,9 @@ class HashPolicyImpl : public HashPolicy { virtual absl::optional evaluate(const Network::Address::Instance* downstream_addr, const Http::HeaderMap& headers, const AddCookieCallback add_cookie) const PURE; + + // If the method is a terminal method, ignore rest of the hash policy chain. + virtual bool terminal() const PURE; }; typedef std::unique_ptr HashMethodPtr; diff --git a/source/common/secret/BUILD b/source/common/secret/BUILD index f45602e4c10c9..d0106e0eaeba5 100644 --- a/source/common/secret/BUILD +++ b/source/common/secret/BUILD @@ -52,6 +52,7 @@ envoy_cc_library( "//source/common/config:resources_lib", "//source/common/config:subscription_factory_lib", "//source/common/protobuf:utility_lib", + "//source/common/ssl:certificate_validation_context_config_impl_lib", "//source/common/ssl:tls_certificate_config_impl_lib", ], ) diff --git a/source/common/secret/sds_api.cc b/source/common/secret/sds_api.cc index f6d3b9f9db984..1678df0828d99 100644 --- a/source/common/secret/sds_api.cc +++ b/source/common/secret/sds_api.cc @@ -7,6 +7,7 @@ #include "common/config/resources.h" #include "common/config/subscription_factory.h" #include "common/protobuf/utility.h" +#include "common/ssl/certificate_validation_context_config_impl.h" #include "common/ssl/tls_certificate_config_impl.h" namespace Envoy { @@ -15,11 +16,11 @@ namespace Secret { SdsApi::SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, Runtime::RandomGenerator& random, Stats::Store& stats, Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, - const envoy::api::v2::core::ConfigSource& sds_config, std::string sds_config_name, - std::function destructor_cb) - : local_info_(local_info), dispatcher_(dispatcher), random_(random), stats_(stats), - cluster_manager_(cluster_manager), sds_config_(sds_config), sds_config_name_(sds_config_name), - secret_hash_(0), clean_up_(destructor_cb) { + const envoy::api::v2::core::ConfigSource& sds_config, + const std::string& sds_config_name, std::function destructor_cb) + : secret_hash_(0), local_info_(local_info), dispatcher_(dispatcher), random_(random), + stats_(stats), cluster_manager_(cluster_manager), sds_config_(sds_config), + sds_config_name_(sds_config_name), clean_up_(destructor_cb) { // TODO(JimmyCYJ): Implement chained_init_manager, so that multiple init_manager // can be chained together to behave as one init_manager. In that way, we let // two listeners which share same SdsApi to register at separate init managers, and @@ -59,15 +60,7 @@ void SdsApi::onConfigUpdate(const ResourceVector& resources, const std::string&) fmt::format("Unexpected SDS secret (expecting {}): {}", sds_config_name_, secret.name())); } - const uint64_t new_hash = MessageUtil::hash(secret); - if (new_hash != secret_hash_ && - secret.type_case() == envoy::api::v2::auth::Secret::TypeCase::kTlsCertificate) { - secret_hash_ = new_hash; - tls_certificate_secrets_ = - std::make_unique(secret.tls_certificate()); - - update_callback_manager_.runCallbacks(); - } + updateConfigHelper(secret); runInitializeCallbackIfAny(); } @@ -84,5 +77,30 @@ void SdsApi::runInitializeCallbackIfAny() { } } +void TlsCertificateSdsApi::updateConfigHelper(const envoy::api::v2::auth::Secret& secret) { + const uint64_t new_hash = MessageUtil::hash(secret); + if (new_hash != secret_hash_ && + secret.type_case() == envoy::api::v2::auth::Secret::TypeCase::kTlsCertificate) { + secret_hash_ = new_hash; + tls_certificate_secrets_ = + std::make_unique(secret.tls_certificate()); + + update_callback_manager_.runCallbacks(); + } +} + +void CertificateValidationContextSdsApi::updateConfigHelper( + const envoy::api::v2::auth::Secret& secret) { + const uint64_t new_hash = MessageUtil::hash(secret); + if (new_hash != secret_hash_ && + secret.type_case() == envoy::api::v2::auth::Secret::TypeCase::kValidationContext) { + secret_hash_ = new_hash; + certificate_validation_context_secrets_ = + std::make_unique(secret.validation_context()); + + update_callback_manager_.runCallbacks(); + } +} + } // namespace Secret } // namespace Envoy diff --git a/source/common/secret/sds_api.h b/source/common/secret/sds_api.h index f9211e3832b28..23f31a4100ea7 100644 --- a/source/common/secret/sds_api.h +++ b/source/common/secret/sds_api.h @@ -24,13 +24,12 @@ namespace Secret { * SDS API implementation that fetches secrets from SDS server via Subscription. */ class SdsApi : public Init::Target, - public TlsCertificateConfigProvider, public Config::SubscriptionCallbacks { public: SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, Runtime::RandomGenerator& random, Stats::Store& stats, Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, - const envoy::api::v2::core::ConfigSource& sds_config, std::string sds_config_name, + const envoy::api::v2::core::ConfigSource& sds_config, const std::string& sds_config_name, std::function destructor_cb); // Init::Target @@ -43,14 +42,11 @@ class SdsApi : public Init::Target, return MessageUtil::anyConvert(resource).name(); } - // SecretProvider - const Ssl::TlsCertificateConfig* secret() const override { - return tls_certificate_secrets_.get(); - } - - Common::CallbackHandle* addUpdateCallback(std::function callback) override { - return update_callback_manager_.add(callback); - } +protected: + // Updates local storage of dynamic secrets and invokes callbacks. + virtual void updateConfigHelper(const envoy::api::v2::auth::Secret&) PURE; + uint64_t secret_hash_; + Common::CallbackManager<> update_callback_manager_; private: void runInitializeCallbackIfAny(); @@ -66,13 +62,71 @@ class SdsApi : public Init::Target, std::function initialize_callback_; const std::string sds_config_name_; - uint64_t secret_hash_; Cleanup clean_up_; +}; + +typedef std::shared_ptr SdsApiSharedPtr; + +/** + * TlsCertificateSdsApi implementation maintains and updates dynamic TLS certificate secrets. + */ +class TlsCertificateSdsApi : public SdsApi, public TlsCertificateConfigProvider { +public: + TlsCertificateSdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, + Runtime::RandomGenerator& random, Stats::Store& stats, + Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, + const envoy::api::v2::core::ConfigSource& sds_config, + std::string sds_config_name, std::function destructor_cb) + : SdsApi(local_info, dispatcher, random, stats, cluster_manager, init_manager, sds_config, + sds_config_name, destructor_cb) {} + + // SecretProvider + const Ssl::TlsCertificateConfig* secret() const override { + return tls_certificate_secrets_.get(); + } + Common::CallbackHandle* addUpdateCallback(std::function callback) override { + return update_callback_manager_.add(callback); + } + +private: + // SdsApi + void updateConfigHelper(const envoy::api::v2::auth::Secret& secret) override; + Ssl::TlsCertificateConfigPtr tls_certificate_secrets_; - Common::CallbackManager<> update_callback_manager_; }; -typedef std::unique_ptr SdsApiPtr; +/** + * CertificateValidationContextSdsApi implementation maintains and updates dynamic certificate + * validation context secrets. + */ +class CertificateValidationContextSdsApi : public SdsApi, + public CertificateValidationContextConfigProvider { +public: + CertificateValidationContextSdsApi(const LocalInfo::LocalInfo& local_info, + Event::Dispatcher& dispatcher, + Runtime::RandomGenerator& random, Stats::Store& stats, + Upstream::ClusterManager& cluster_manager, + Init::Manager& init_manager, + const envoy::api::v2::core::ConfigSource& sds_config, + std::string sds_config_name, + std::function destructor_cb) + : SdsApi(local_info, dispatcher, random, stats, cluster_manager, init_manager, sds_config, + sds_config_name, destructor_cb) {} + + // SecretProvider + const Ssl::CertificateValidationContextConfig* secret() const override { + return certificate_validation_context_secrets_.get(); + } + Common::CallbackHandle* addUpdateCallback(std::function callback) override { + return update_callback_manager_.add(callback); + } + +private: + // SdsApi + void updateConfigHelper(const envoy::api::v2::auth::Secret& secret) override; + + Ssl::CertificateValidationContextConfigPtr certificate_validation_context_secrets_; +}; } // namespace Secret } // namespace Envoy diff --git a/source/common/secret/secret_manager_impl.cc b/source/common/secret/secret_manager_impl.cc index a310d59777ac9..43fe82126fc06 100644 --- a/source/common/secret/secret_manager_impl.cc +++ b/source/common/secret/secret_manager_impl.cc @@ -65,7 +65,7 @@ SecretManagerImpl::createInlineCertificateValidationContextProvider( } void SecretManagerImpl::removeDynamicSecretProvider(const std::string& map_key) { - ENVOY_LOG(debug, "Unregister secret provider. hash key: {}", map_key); + ENVOY_LOG(debug, "Unregister tls certificate provider. hash key: {}", map_key); auto num_deleted = dynamic_secret_providers_.erase(map_key); ASSERT(num_deleted == 1, ""); @@ -76,7 +76,7 @@ TlsCertificateConfigProviderSharedPtr SecretManagerImpl::findOrCreateTlsCertific Server::Configuration::TransportSocketFactoryContext& secret_provider_context) { const std::string map_key = sds_config_source.SerializeAsString() + config_name; - TlsCertificateConfigProviderSharedPtr secret_provider = dynamic_secret_providers_[map_key].lock(); + SdsApiSharedPtr secret_provider = dynamic_secret_providers_[map_key].lock(); if (!secret_provider) { ASSERT(secret_provider_context.initManager() != nullptr); @@ -86,7 +86,7 @@ TlsCertificateConfigProviderSharedPtr SecretManagerImpl::findOrCreateTlsCertific removeDynamicSecretProvider(map_key); }; - secret_provider = std::make_shared( + secret_provider = std::make_shared( secret_provider_context.localInfo(), secret_provider_context.dispatcher(), secret_provider_context.random(), secret_provider_context.stats(), secret_provider_context.clusterManager(), *secret_provider_context.initManager(), @@ -94,7 +94,34 @@ TlsCertificateConfigProviderSharedPtr SecretManagerImpl::findOrCreateTlsCertific dynamic_secret_providers_[map_key] = secret_provider; } - return secret_provider; + return std::dynamic_pointer_cast(secret_provider); +} + +CertificateValidationContextConfigProviderSharedPtr +SecretManagerImpl::findOrCreateCertificateValidationContextProvider( + const envoy::api::v2::core::ConfigSource& sds_config_source, const std::string& config_name, + Server::Configuration::TransportSocketFactoryContext& secret_provider_context) { + const std::string map_key = sds_config_source.SerializeAsString() + config_name; + + SdsApiSharedPtr secret_provider = dynamic_secret_providers_[map_key].lock(); + if (!secret_provider) { + ASSERT(secret_provider_context.initManager() != nullptr); + + // SdsApi is owned by ListenerImpl and ClusterInfo which are destroyed before + // SecretManagerImpl. It is safe to invoke this callback at the destructor of SdsApi. + std::function unregister_secret_provider = [map_key, this]() { + removeDynamicSecretProvider(map_key); + }; + + secret_provider = std::make_shared( + secret_provider_context.localInfo(), secret_provider_context.dispatcher(), + secret_provider_context.random(), secret_provider_context.stats(), + secret_provider_context.clusterManager(), *secret_provider_context.initManager(), + sds_config_source, config_name, unregister_secret_provider); + dynamic_secret_providers_[map_key] = secret_provider; + } + + return std::dynamic_pointer_cast(secret_provider); } } // namespace Secret diff --git a/source/common/secret/secret_manager_impl.h b/source/common/secret/secret_manager_impl.h index a6017ff8719c3..20e2ac261b33c 100644 --- a/source/common/secret/secret_manager_impl.h +++ b/source/common/secret/secret_manager_impl.h @@ -9,6 +9,7 @@ #include "envoy/ssl/tls_certificate_config.h" #include "common/common/logger.h" +#include "common/secret/sds_api.h" namespace Envoy { namespace Secret { @@ -35,6 +36,11 @@ class SecretManagerImpl : public SecretManager, Logger::Loggable> - dynamic_secret_providers_; + std::unordered_map> dynamic_secret_providers_; }; } // namespace Secret diff --git a/source/common/ssl/BUILD b/source/common/ssl/BUILD index f2e17e9726c54..9cc5179d0e998 100644 --- a/source/common/ssl/BUILD +++ b/source/common/ssl/BUILD @@ -12,7 +12,10 @@ envoy_cc_library( name = "ssl_socket_lib", srcs = ["ssl_socket.cc"], hdrs = ["ssl_socket.h"], - external_deps = ["ssl"], + external_deps = [ + "abseil_synchronization", + "ssl", + ], deps = [ ":context_config_lib", ":context_lib", @@ -23,6 +26,7 @@ envoy_cc_library( "//source/common/common:assert_lib", "//source/common/common:empty_string", "//source/common/common:minimal_logger_lib", + "//source/common/common:thread_annotations", "//source/common/http:headers_lib", ], ) diff --git a/source/common/ssl/context_config_impl.cc b/source/common/ssl/context_config_impl.cc index aef7c1bdeec19..e6ecfa64189f4 100644 --- a/source/common/ssl/context_config_impl.cc +++ b/source/common/ssl/context_config_impl.cc @@ -65,6 +65,9 @@ getCertificateValidationContextConfigProvider( sds_secret_config.name())); } return secret_provider; + } else { + return factory_context.secretManager().findOrCreateCertificateValidationContextProvider( + sds_secret_config.sds_config(), sds_secret_config.name(), factory_context); } } return nullptr; @@ -98,17 +101,21 @@ ContextConfigImpl::ContextConfigImpl( ecdh_curves_(StringUtil::nonEmptyStringOrDefault( RepeatedPtrUtil::join(config.tls_params().ecdh_curves(), ":"), DEFAULT_ECDH_CURVES)), tls_certficate_provider_(getTlsCertificateConfigProvider(config, factory_context)), - secret_update_callback_handle_(nullptr), + tls_certificate_update_callback_handle_(nullptr), certficate_validation_context_provider_( getCertificateValidationContextConfigProvider(config, factory_context)), + certificate_validation_context_update_callback_handle_(nullptr), min_protocol_version_( tlsVersionFromProto(config.tls_params().tls_minimum_protocol_version(), TLS1_VERSION)), max_protocol_version_(tlsVersionFromProto(config.tls_params().tls_maximum_protocol_version(), TLS1_2_VERSION)) {} ContextConfigImpl::~ContextConfigImpl() { - if (secret_update_callback_handle_) { - secret_update_callback_handle_->remove(); + if (tls_certificate_update_callback_handle_) { + tls_certificate_update_callback_handle_->remove(); + } + if (certificate_validation_context_update_callback_handle_) { + certificate_validation_context_update_callback_handle_->remove(); } } diff --git a/source/common/ssl/context_config_impl.h b/source/common/ssl/context_config_impl.h index 7e6443be9a090..90e048cc8495e 100644 --- a/source/common/ssl/context_config_impl.h +++ b/source/common/ssl/context_config_impl.h @@ -39,16 +39,28 @@ class ContextConfigImpl : public virtual Ssl::ContextConfig { bool isReady() const override { // Either tls_certficate_provider_ is nullptr or - // tls_certficate_provider_->secret() is NOT nullptr. - return !tls_certficate_provider_ || tls_certficate_provider_->secret() != nullptr; + // tls_certficate_provider_->secret() is NOT nullptr and + // either certficate_validation_context_provider_ is nullptr or + // certficate_validation_context_provider_->secret() is NOT nullptr. + return (!tls_certficate_provider_ || tls_certficate_provider_->secret() != nullptr) && + (!certficate_validation_context_provider_ || + certficate_validation_context_provider_->secret() != nullptr); } void setSecretUpdateCallback(std::function callback) override { if (tls_certficate_provider_) { - if (secret_update_callback_handle_) { - secret_update_callback_handle_->remove(); + if (tls_certificate_update_callback_handle_) { + tls_certificate_update_callback_handle_->remove(); } - secret_update_callback_handle_ = tls_certficate_provider_->addUpdateCallback(callback); + tls_certificate_update_callback_handle_ = + tls_certficate_provider_->addUpdateCallback(callback); + } + if (certficate_validation_context_provider_) { + if (certificate_validation_context_update_callback_handle_) { + certificate_validation_context_update_callback_handle_->remove(); + } + certificate_validation_context_update_callback_handle_ = + certficate_validation_context_provider_->addUpdateCallback(callback); } } @@ -69,9 +81,10 @@ class ContextConfigImpl : public virtual Ssl::ContextConfig { const std::string cipher_suites_; const std::string ecdh_curves_; Secret::TlsCertificateConfigProviderSharedPtr tls_certficate_provider_; - Common::CallbackHandle* secret_update_callback_handle_; + Common::CallbackHandle* tls_certificate_update_callback_handle_; Secret::CertificateValidationContextConfigProviderSharedPtr certficate_validation_context_provider_; + Common::CallbackHandle* certificate_validation_context_update_callback_handle_; const unsigned min_protocol_version_; const unsigned max_protocol_version_; }; diff --git a/source/common/ssl/ssl_socket.cc b/source/common/ssl/ssl_socket.cc index 5ba3ac9eb7dde..757c0d9a87fe3 100644 --- a/source/common/ssl/ssl_socket.cc +++ b/source/common/ssl/ssl_socket.cc @@ -422,7 +422,11 @@ Network::TransportSocketPtr ClientSslSocketFactory::createTransportSocket() cons // onAddOrUpdateSecret() could be invoked in the middle of checking the existence of ssl_ctx and // creating SslSocket using ssl_ctx. Capture ssl_ctx_ into a local variable so that we check and // use the same ssl_ctx to create SslSocket. - auto ssl_ctx = ssl_ctx_; + ClientContextSharedPtr ssl_ctx; + { + absl::ReaderMutexLock l(&ssl_ctx_mu_); + ssl_ctx = ssl_ctx_; + } if (ssl_ctx) { return std::make_unique(std::move(ssl_ctx), Ssl::InitialState::Client); } else { @@ -436,7 +440,10 @@ bool ClientSslSocketFactory::implementsSecureTransport() const { return true; } void ClientSslSocketFactory::onAddOrUpdateSecret() { ENVOY_LOG(debug, "Secret is updated."); - ssl_ctx_ = manager_.createSslClientContext(stats_scope_, *config_); + { + absl::WriterMutexLock l(&ssl_ctx_mu_); + ssl_ctx_ = manager_.createSslClientContext(stats_scope_, *config_); + } stats_.ssl_context_update_by_sds_.inc(); } @@ -454,7 +461,11 @@ Network::TransportSocketPtr ServerSslSocketFactory::createTransportSocket() cons // onAddOrUpdateSecret() could be invoked in the middle of checking the existence of ssl_ctx and // creating SslSocket using ssl_ctx. Capture ssl_ctx_ into a local variable so that we check and // use the same ssl_ctx to create SslSocket. - auto ssl_ctx = ssl_ctx_; + ServerContextSharedPtr ssl_ctx; + { + absl::ReaderMutexLock l(&ssl_ctx_mu_); + ssl_ctx = ssl_ctx_; + } if (ssl_ctx) { return std::make_unique(std::move(ssl_ctx), Ssl::InitialState::Server); } else { @@ -468,7 +479,10 @@ bool ServerSslSocketFactory::implementsSecureTransport() const { return true; } void ServerSslSocketFactory::onAddOrUpdateSecret() { ENVOY_LOG(debug, "Secret is updated."); - ssl_ctx_ = manager_.createSslServerContext(stats_scope_, *config_, server_names_); + { + absl::WriterMutexLock l(&ssl_ctx_mu_); + ssl_ctx_ = manager_.createSslServerContext(stats_scope_, *config_, server_names_); + } stats_.ssl_context_update_by_sds_.inc(); } diff --git a/source/common/ssl/ssl_socket.h b/source/common/ssl/ssl_socket.h index 951af874a655f..1cdfce9d1711b 100644 --- a/source/common/ssl/ssl_socket.h +++ b/source/common/ssl/ssl_socket.h @@ -12,6 +12,7 @@ #include "common/common/logger.h" #include "common/ssl/context_impl.h" +#include "absl/synchronization/mutex.h" #include "openssl/ssl.h" namespace Envoy { @@ -101,7 +102,8 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory, Stats::Scope& stats_scope_; SslSocketFactoryStats stats_; ClientContextConfigPtr config_; - ClientContextSharedPtr ssl_ctx_; + mutable absl::Mutex ssl_ctx_mu_; + ClientContextSharedPtr ssl_ctx_ GUARDED_BY(ssl_ctx_mu_); }; class ServerSslSocketFactory : public Network::TransportSocketFactory, @@ -123,7 +125,8 @@ class ServerSslSocketFactory : public Network::TransportSocketFactory, SslSocketFactoryStats stats_; ServerContextConfigPtr config_; const std::vector server_names_; - ServerContextSharedPtr ssl_ctx_; + mutable absl::Mutex ssl_ctx_mu_; + ServerContextSharedPtr ssl_ctx_ GUARDED_BY(ssl_ctx_mu_); }; } // namespace Ssl diff --git a/source/common/tcp_proxy/tcp_proxy.cc b/source/common/tcp_proxy/tcp_proxy.cc index eb7d4196d9107..e4517f158111e 100644 --- a/source/common/tcp_proxy/tcp_proxy.cc +++ b/source/common/tcp_proxy/tcp_proxy.cc @@ -135,13 +135,8 @@ Filter::~Filter() { access_log->log(nullptr, nullptr, nullptr, getRequestInfo()); } - if (upstream_handle_) { - upstream_handle_->cancel(); - } - - if (upstream_conn_data_) { - upstream_conn_data_->connection().close(Network::ConnectionCloseType::NoFlush); - } + ASSERT(upstream_handle_ == nullptr); + ASSERT(upstream_conn_data_ == nullptr); } TcpProxyStats Config::SharedConfig::generateStats(Stats::Scope& scope) { @@ -412,17 +407,29 @@ void Filter::onDownstreamEvent(Network::ConnectionEvent event) { if (event == Network::ConnectionEvent::RemoteClose) { upstream_conn_data_->connection().close(Network::ConnectionCloseType::FlushWrite); - if (upstream_conn_data_ != nullptr && - upstream_conn_data_->connection().state() != Network::Connection::State::Closed) { - config_->drainManager().add(config_->sharedConfig(), std::move(upstream_conn_data_), - std::move(upstream_callbacks_), std::move(idle_timer_), - read_callbacks_->upstreamHost()); + // Events raised from the previous line may cause upstream_conn_data_ to be NULL if + // it was able to immediately flush all data. + + if (upstream_conn_data_ != nullptr) { + if (upstream_conn_data_->connection().state() != Network::Connection::State::Closed) { + config_->drainManager().add(config_->sharedConfig(), std::move(upstream_conn_data_), + std::move(upstream_callbacks_), std::move(idle_timer_), + read_callbacks_->upstreamHost()); + } else { + upstream_conn_data_.reset(); + } } } else if (event == Network::ConnectionEvent::LocalClose) { upstream_conn_data_->connection().close(Network::ConnectionCloseType::NoFlush); upstream_conn_data_.reset(); disableIdleTimer(); } + } else if (upstream_handle_) { + if (event == Network::ConnectionEvent::LocalClose || + event == Network::ConnectionEvent::RemoteClose) { + upstream_handle_->cancel(); + upstream_handle_ = nullptr; + } } } diff --git a/source/extensions/filters/network/thrift_proxy/BUILD b/source/extensions/filters/network/thrift_proxy/BUILD index 25c6dfda0dda7..e2d1789976b63 100644 --- a/source/extensions/filters/network/thrift_proxy/BUILD +++ b/source/extensions/filters/network/thrift_proxy/BUILD @@ -105,6 +105,7 @@ envoy_cc_library( external_deps = ["abseil_optional"], deps = [ ":thrift_lib", + "//include/envoy/buffer:buffer_interface", "//source/common/common:macros", ], ) @@ -128,14 +129,18 @@ envoy_cc_library( ], external_deps = ["abseil_optional"], deps = [ + ":conn_state_lib", ":decoder_events_lib", ":metadata_lib", ":thrift_lib", + ":thrift_object_interface", + ":transport_interface", "//include/envoy/buffer:buffer_interface", "//include/envoy/registry", "//source/common/common:assert_lib", "//source/common/config:utility_lib", "//source/common/singleton:const_singleton", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", ], ) @@ -221,6 +226,14 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "conn_state_lib", + hdrs = ["conn_state.h"], + deps = [ + "//include/envoy/tcp:conn_pool_interface", + ], +) + envoy_cc_library( name = "thrift_lib", hdrs = ["thrift.h"], @@ -230,6 +243,27 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "thrift_object_interface", + hdrs = ["thrift_object.h"], + deps = [ + "//include/envoy/buffer:buffer_interface", + ], +) + +envoy_cc_library( + name = "thrift_object_lib", + srcs = ["thrift_object_impl.cc"], + hdrs = ["thrift_object_impl.h"], + deps = [ + ":decoder_lib", + ":thrift_lib", + ":thrift_object_interface", + ":unframed_transport_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + ], +) + envoy_cc_library( name = "auto_transport_lib", srcs = [ diff --git a/source/extensions/filters/network/thrift_proxy/config.cc b/source/extensions/filters/network/thrift_proxy/config.cc index bba8a035a25bf..359cf7f77ed5c 100644 --- a/source/extensions/filters/network/thrift_proxy/config.cc +++ b/source/extensions/filters/network/thrift_proxy/config.cc @@ -142,10 +142,6 @@ void ConfigImpl::createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& c } } -DecoderPtr ConfigImpl::createDecoder(DecoderCallbacks& callbacks) { - return std::make_unique(createTransport(), createProtocol(), callbacks); -} - TransportPtr ConfigImpl::createTransport() { return NamedTransportConfigFactory::getFactory(transport_).createTransport(); } diff --git a/source/extensions/filters/network/thrift_proxy/config.h b/source/extensions/filters/network/thrift_proxy/config.h index 0b32e7964c4ea..71c2d1c580c91 100644 --- a/source/extensions/filters/network/thrift_proxy/config.h +++ b/source/extensions/filters/network/thrift_proxy/config.h @@ -76,13 +76,11 @@ class ConfigImpl : public Config, // Config ThriftFilterStats& stats() override { return stats_; } ThriftFilters::FilterChainFactory& filterFactory() override { return *this; } - DecoderPtr createDecoder(DecoderCallbacks& callbacks) override; + TransportPtr createTransport() override; + ProtocolPtr createProtocol() override; Router::Config& routerConfig() override { return *this; } private: - TransportPtr createTransport(); - ProtocolPtr createProtocol(); - Server::Configuration::FactoryContext& context_; const std::string stats_prefix_; ThriftFilterStats stats_; diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc index 6b83435351b9d..de34169fc385f 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.cc +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -16,7 +16,9 @@ namespace NetworkFilters { namespace ThriftProxy { ConnectionManager::ConnectionManager(Config& config) - : config_(config), stats_(config_.stats()), decoder_(config_.createDecoder(*this)) {} + : config_(config), stats_(config_.stats()), transport_(config.createTransport()), + protocol_(config.createProtocol()), + decoder_(std::make_unique(*transport_, *protocol_, *this)) {} ConnectionManager::~ConnectionManager() {} @@ -67,22 +69,14 @@ void ConnectionManager::dispatch() { } void ConnectionManager::sendLocalReply(MessageMetadata& metadata, const DirectResponse& response) { - // Use the factory to get the concrete protocol from the decoder protocol (as opposed to - // potentially pre-detection auto protocol). - ProtocolType proto_type = decoder_->protocolType(); - ProtocolPtr proto = NamedProtocolConfigFactory::getFactory(proto_type).createProtocol(); Buffer::OwnedImpl buffer; - response.encode(metadata, *proto, buffer); - - // Same logic as protocol above. - TransportPtr transport = - NamedTransportConfigFactory::getFactory(decoder_->transportType()).createTransport(); + response.encode(metadata, *protocol_, buffer); Buffer::OwnedImpl response_buffer; - metadata.setProtocol(proto_type); - transport->encodeFrame(response_buffer, metadata, buffer); + metadata.setProtocol(protocol_->type()); + transport_->encodeFrame(response_buffer, metadata, buffer); read_callbacks_->connection().write(response_buffer, false); } @@ -230,12 +224,30 @@ FilterStatus ConnectionManager::ActiveRpc::transportEnd() { break; } - return decoder_filter_->transportEnd(); + FilterStatus status = event_handler_->transportEnd(); + + if (metadata_->isProtocolUpgradeMessage()) { + ENVOY_CONN_LOG(error, "thrift: sending protocol upgrade response", + parent_.read_callbacks_->connection()); + sendLocalReply(*parent_.protocol_->upgradeResponse(*upgrade_handler_)); + } + + return status; } FilterStatus ConnectionManager::ActiveRpc::messageBegin(MessageMetadataSharedPtr metadata) { metadata_ = metadata; + if (metadata_->isProtocolUpgradeMessage()) { + ASSERT(parent_.protocol_->supportsUpgrade()); + + ENVOY_CONN_LOG(error, "thrift: decoding protocol upgrade request", + parent_.read_callbacks_->connection()); + upgrade_handler_ = parent_.protocol_->upgradeRequestDecoder(); + ASSERT(upgrade_handler_ != nullptr); + event_handler_ = upgrade_handler_.get(); + } + return event_handler_->messageBegin(metadata); } @@ -282,11 +294,10 @@ void ConnectionManager::ActiveRpc::sendLocalReply(const DirectResponse& response parent_.doDeferredRpcDestroy(*this); } -void ConnectionManager::ActiveRpc::startUpstreamResponse(TransportType transport_type, - ProtocolType protocol_type) { +void ConnectionManager::ActiveRpc::startUpstreamResponse(Transport& transport, Protocol& protocol) { ASSERT(response_decoder_ == nullptr); - response_decoder_ = std::make_unique(*this, transport_type, protocol_type); + response_decoder_ = std::make_unique(*this, transport, protocol); } bool ConnectionManager::ActiveRpc::upstreamData(Buffer::Instance& buffer) { diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h index 5d4f4f9cf1151..6432a51377960 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.h +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -31,7 +31,8 @@ class Config { virtual ThriftFilters::FilterChainFactory& filterFactory() PURE; virtual ThriftFilterStats& stats() PURE; - virtual DecoderPtr createDecoder(DecoderCallbacks& callbacks) PURE; + virtual TransportPtr createTransport() PURE; + virtual ProtocolPtr createProtocol() PURE; virtual Router::Config& routerConfig() PURE; }; @@ -74,18 +75,10 @@ class ConnectionManager : public Network::ReadFilter, struct ActiveRpc; struct ResponseDecoder : public DecoderCallbacks, public ProtocolConverter { - ResponseDecoder(ActiveRpc& parent, TransportType transport_type, ProtocolType protocol_type) - : parent_(parent), - decoder_(std::make_unique( - NamedTransportConfigFactory::getFactory(transport_type).createTransport(), - NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol(), *this)), + ResponseDecoder(ActiveRpc& parent, Transport& transport, Protocol& protocol) + : parent_(parent), decoder_(std::make_unique(transport, protocol, *this)), complete_(false), first_reply_field_(false) { - // Use the factory to get the concrete protocol from the decoder protocol (as opposed to - // potentially pre-detection auto protocol). - initProtocolConverter( - NamedProtocolConfigFactory::getFactory(parent_.parent_.decoder_->protocolType()) - .createProtocol(), - parent_.response_buffer_); + initProtocolConverter(*parent_.parent_.protocol_, parent_.response_buffer_); } bool onData(Buffer::Instance& data); @@ -149,7 +142,7 @@ class ConnectionManager : public Network::ReadFilter, return parent_.decoder_->protocolType(); } void sendLocalReply(const DirectResponse& response) override; - void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) override; + void startUpstreamResponse(Transport& transport, Protocol& protocol) override; bool upstreamData(Buffer::Instance& buffer) override; void resetDownstreamConnection() override; @@ -170,6 +163,7 @@ class ConnectionManager : public Network::ReadFilter, uint64_t stream_id_; MessageMetadataSharedPtr metadata_; ThriftFilters::DecoderFilterSharedPtr decoder_filter_; + DecoderEventHandlerSharedPtr upgrade_handler_; ResponseDecoderPtr response_decoder_; absl::optional cached_route_; Buffer::OwnedImpl response_buffer_; @@ -188,6 +182,8 @@ class ConnectionManager : public Network::ReadFilter, Network::ReadFilterCallbacks* read_callbacks_{}; + TransportPtr transport_; + ProtocolPtr protocol_; DecoderPtr decoder_; std::list rpcs_; Buffer::OwnedImpl request_buffer_; diff --git a/source/extensions/filters/network/thrift_proxy/conn_state.h b/source/extensions/filters/network/thrift_proxy/conn_state.h new file mode 100644 index 0000000000000..f5db30d458461 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/conn_state.h @@ -0,0 +1,48 @@ +#pragma once + +#include "envoy/tcp/conn_pool.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * ThriftConnectionState tracks thrift-related connection state for pooled connections. + */ +class ThriftConnectionState : public Tcp::ConnectionPool::ConnectionState { +public: + /** + * @return true if this upgrade has been attempted on this connection. + */ + bool upgradeAttempted() const { return upgrade_attempted_; } + /** + * @return true if this connection has been upgraded + */ + bool isUpgraded() const { return upgraded_; } + + /** + * Marks the connection as successfully upgraded. + */ + void markUpgraded() { + upgrade_attempted_ = true; + upgraded_ = true; + } + + /** + * Marks the connection as not upgraded. + */ + void markUpgradeFailed() { + upgrade_attempted_ = true; + upgraded_ = false; + } + +private: + bool upgrade_attempted_{false}; + bool upgraded_{false}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index c6d4b470ea890..bc56bffea0b00 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -334,6 +334,9 @@ ProtocolState DecoderStateMachine::popReturnState() { ProtocolState DecoderStateMachine::run(Buffer::Instance& buffer) { while (state_ != ProtocolState::Done) { + ENVOY_LOG(trace, "thrift: state {}, {} bytes available", ProtocolStateNameValues::name(state_), + buffer.length()); + DecoderStatus s = handleState(buffer); if (s.next_state_ == ProtocolState::WaitForData) { return ProtocolState::WaitForData; @@ -350,8 +353,8 @@ ProtocolState DecoderStateMachine::run(Buffer::Instance& buffer) { return state_; } -Decoder::Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks) - : transport_(std::move(transport)), protocol_(std::move(protocol)), callbacks_(callbacks) {} +Decoder::Decoder(Transport& transport, Protocol& protocol, DecoderCallbacks& callbacks) + : transport_(transport), protocol_(protocol), callbacks_(callbacks) {} void Decoder::complete() { request_.reset(); @@ -377,22 +380,22 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { metadata_ = std::make_shared(); } - if (!transport_->decodeFrameStart(data, *metadata_)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_->name()); + if (!transport_.decodeFrameStart(data, *metadata_)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_.name()); buffer_underflow = true; return FilterStatus::Continue; } - ENVOY_LOG(debug, "thrift: {} transport started", transport_->name()); + ENVOY_LOG(debug, "thrift: {} transport started", transport_.name()); if (metadata_->hasProtocol()) { - if (protocol_->type() == ProtocolType::Auto) { - protocol_->setType(metadata_->protocol()); - ENVOY_LOG(debug, "thrift: {} transport forced {} protocol", transport_->name(), - protocol_->name()); - } else if (metadata_->protocol() != protocol_->type()) { + if (protocol_.type() == ProtocolType::Auto) { + protocol_.setType(metadata_->protocol()); + ENVOY_LOG(debug, "thrift: {} transport forced {} protocol", transport_.name(), + protocol_.name()); + } else if (metadata_->protocol() != protocol_.type()) { throw EnvoyException(fmt::format("transport reports protocol {}, but configured for {}", ProtocolNames::get().fromType(metadata_->protocol()), - ProtocolNames::get().fromType(protocol_->type()))); + ProtocolNames::get().fromType(protocol_.type()))); } } if (metadata_->hasAppException()) { @@ -406,7 +409,7 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { request_ = std::make_unique(callbacks_.newDecoderEventHandler()); frame_started_ = true; state_machine_ = - std::make_unique(*protocol_, metadata_, request_->handler_); + std::make_unique(protocol_, metadata_, request_->handler_); if (request_->handler_.transportBegin(metadata_) == FilterStatus::StopIteration) { return FilterStatus::StopIteration; @@ -415,7 +418,7 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { ASSERT(state_machine_ != nullptr); - ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_->name(), + ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_.name(), ProtocolStateNameValues::name(state_machine_->currentState()), data.length()); ProtocolState rv = state_machine_->run(data); @@ -431,8 +434,8 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { ASSERT(rv == ProtocolState::Done); // Message complete, decode end of frame. - if (!transport_->decodeFrameEnd(data)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_->name()); + if (!transport_.decodeFrameEnd(data)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_.name()); buffer_underflow = true; return FilterStatus::Continue; } @@ -440,7 +443,7 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { frame_ended_ = true; metadata_.reset(); - ENVOY_LOG(debug, "thrift: {} transport ended", transport_->name()); + ENVOY_LOG(debug, "thrift: {} transport ended", transport_.name()); if (request_->handler_.transportEnd() == FilterStatus::StopIteration) { return FilterStatus::StopIteration; } diff --git a/source/extensions/filters/network/thrift_proxy/decoder.h b/source/extensions/filters/network/thrift_proxy/decoder.h index cda5d75b4a8b7..e2886aedebd60 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.h +++ b/source/extensions/filters/network/thrift_proxy/decoder.h @@ -2,7 +2,6 @@ #include "envoy/buffer/buffer.h" -#include "common/buffer/buffer_impl.h" #include "common/common/assert.h" #include "common/common/logger.h" @@ -61,7 +60,7 @@ class ProtocolStateNameValues { * DecoderStateMachine is the Thrift message state machine as described in * source/extensions/filters/network/thrift_proxy/docs. */ -class DecoderStateMachine { +class DecoderStateMachine : public Logger::Loggable { public: DecoderStateMachine(Protocol& proto, MessageMetadataSharedPtr& metadata, DecoderEventHandler& handler) @@ -183,15 +182,15 @@ class DecoderCallbacks { }; /** - * Decoder encapsulates a configured TransportPtr and ProtocolPtr. + * Decoder encapsulates a configured Transport and Protocol and provides the ability to decode + * Thrift messages. */ class Decoder : public Logger::Loggable { public: - Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks); - Decoder(TransportType transport_type, ProtocolType protocol_type, DecoderCallbacks& callbacks); + Decoder(Transport& transport, Protocol& protocol, DecoderCallbacks& callbacks); /** - * Drains data from the given buffer while executing a DecoderStateMachine over the data. + * Drains data from the given buffer while executing a state machine over the data. * * @param data a Buffer containing Thrift protocol data * @param buffer_underflow bool set to true if more data is required to continue decoding @@ -201,8 +200,8 @@ class Decoder : public Logger::Loggable { */ FilterStatus onData(Buffer::Instance& data, bool& buffer_underflow); - TransportType transportType() { return transport_->type(); } - ProtocolType protocolType() { return protocol_->type(); } + TransportType transportType() { return transport_.type(); } + ProtocolType protocolType() { return protocol_.type(); } private: struct ActiveRequest { @@ -214,8 +213,8 @@ class Decoder : public Logger::Loggable { void complete(); - TransportPtr transport_; - ProtocolPtr protocol_; + Transport& transport_; + Protocol& protocol_; DecoderCallbacks& callbacks_; ActiveRequestPtr request_; MessageMetadataSharedPtr metadata_; diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h index 304db348d9e29..4183514455b91 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/filter.h +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -68,10 +68,10 @@ class DecoderFilterCallbacks { /** * Indicates the start of an upstream response. May only be called once. - * @param transport_type TransportType the upstream is using - * @param protocol_type ProtocolType the upstream is using + * @param transport the transport used by the upstream response + * @param protocol the protocol used by the upstream response */ - virtual void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) PURE; + virtual void startUpstreamResponse(Transport& transport, Protocol& protocol) PURE; /** * Called with upstream response data. diff --git a/source/extensions/filters/network/thrift_proxy/metadata.h b/source/extensions/filters/network/thrift_proxy/metadata.h index e9498d88eef43..bb659e5afdc47 100644 --- a/source/extensions/filters/network/thrift_proxy/metadata.h +++ b/source/extensions/filters/network/thrift_proxy/metadata.h @@ -4,8 +4,11 @@ #include #include +#include #include +#include "envoy/buffer/buffer.h" + #include "common/common/macros.h" #include "common/http/header_map_impl.h" @@ -62,6 +65,11 @@ class MessageMetadata { AppExceptionType appExceptionType() const { return app_ex_type_.value(); } const std::string& appExceptionMessage() const { return app_ex_msg_.value(); } + bool isProtocolUpgradeMessage() const { return protocol_upgrade_message_; } + void setProtocolUpgradeMessage(bool upgrade_message) { + protocol_upgrade_message_ = upgrade_message; + } + private: absl::optional frame_size_{}; absl::optional proto_{}; @@ -71,6 +79,7 @@ class MessageMetadata { Http::HeaderMapImpl headers_; absl::optional app_ex_type_; absl::optional app_ex_msg_; + bool protocol_upgrade_message_{false}; }; typedef std::shared_ptr MessageMetadataSharedPtr; diff --git a/source/extensions/filters/network/thrift_proxy/protocol.h b/source/extensions/filters/network/thrift_proxy/protocol.h index 0e4066021ce9b..316ca0e79df3a 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol.h +++ b/source/extensions/filters/network/thrift_proxy/protocol.h @@ -11,8 +11,12 @@ #include "common/config/utility.h" #include "common/singleton/const_singleton.h" +#include "extensions/filters/network/thrift_proxy/conn_state.h" +#include "extensions/filters/network/thrift_proxy/decoder_events.h" #include "extensions/filters/network/thrift_proxy/metadata.h" #include "extensions/filters/network/thrift_proxy/thrift.h" +#include "extensions/filters/network/thrift_proxy/thrift_object.h" +#include "extensions/filters/network/thrift_proxy/transport.h" #include "absl/strings/string_view.h" @@ -21,6 +25,9 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +class DirectResponse; +typedef std::unique_ptr DirectResponsePtr; + /** * Protocol represents the operations necessary to implement the a generic Thrift protocol. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-protocol-spec.md @@ -394,10 +401,86 @@ class Protocol { * @param value std::string to write */ virtual void writeBinary(Buffer::Instance& buffer, const std::string& value) PURE; + + /** + * Indicates whether a protocol uses start-of-connection messages to negotiate protocol options. + * If this method returns true, the Protocol must invoke setProtocolUpgradeMessage during + * readMessageBegin if it detects an upgrade request. + * + * @return true for protocols that exchange messages at the start of a connection to negotiate + * protocol upgrade (or options) + */ + virtual bool supportsUpgrade() { return false; } + + /** + * Creates an opaque DecoderEventHandlerSharedPtr that can decode a downstream client's upgrade + * request. When the request is complete, the decoder is passed back to writeUpgradeResponse + * to allow the Protocol to update its internal state and generate a response to the request. + * + * @return a DecoderEventHandlerSharedPtr that decodes a downstream client's upgrade request + */ + virtual DecoderEventHandlerSharedPtr upgradeRequestDecoder() { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + + /** + * Writes a response to a downstream client's upgrade request. + * @param decoder DecoderEventHandlerSharedPtr created by upgradeRequestDecoder + * @return DirectResponsePtr containing an upgrade response + */ + virtual DirectResponsePtr upgradeResponse(const DecoderEventHandler& decoder) { + UNREFERENCED_PARAMETER(decoder); + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + + /** + * Checks whether a given upstream connection can be upgraded and generates an upgrade request + * message. If this method returns a ThriftObject it will be used to decode the upstream's next + * response. + * + * @param transport the Transport to use for decoding the response + * @param state ThriftConnectionState tracking whether upgrade has already been performed + * @param buffer Buffer::Instance to modify with an upgrade request + * @return a ThriftObject capable of decoding an upgrade response or nullptr if upgrade was + * already completed (successfully or not) + */ + virtual ThriftObjectPtr attemptUpgrade(Transport& transport, ThriftConnectionState& state, + Buffer::Instance& buffer) { + UNREFERENCED_PARAMETER(transport); + UNREFERENCED_PARAMETER(state); + UNREFERENCED_PARAMETER(buffer); + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + + /** + * Completes an upgrade previously started via attemptUpgrade. + * @param response ThriftObject created by attemptUpgrade, after the response has completed + * decoding + */ + virtual void completeUpgrade(ThriftConnectionState& state, ThriftObject& response) { + UNREFERENCED_PARAMETER(state); + UNREFERENCED_PARAMETER(response); + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } }; typedef std::unique_ptr ProtocolPtr; +/** + * A DirectResponse manipulates a Protocol to directly create a Thrift response message. + */ +class DirectResponse { +public: + virtual ~DirectResponse() {} + + /** + * Encodes the response via the given Protocol. + * @param metadata the MessageMetadata for the request that generated this response + * @param proto the Protocol to be used for message encoding + * @param buffer the Buffer into which the message should be encoded + */ + virtual void encode(MessageMetadata& metadata, Protocol& proto, + Buffer::Instance& buffer) const PURE; +}; + /** * Implemented by each Thrift protocol and registered via Registry::registerFactory or the * convenience class RegisterFactory. @@ -444,25 +527,6 @@ template class ProtocolFactoryBase : public NamedProtocolCo const std::string name_; }; -/** - * A DirectResponse manipulates a Protocol to directly create a Thrift response message. - */ -class DirectResponse { -public: - virtual ~DirectResponse() {} - - /** - * Encodes the response via the given Protocol. - * @param metadata the MessageMetadata for the request that generated this response - * @param proto the Protocol to be used for message encoding - * @param buffer the Buffer into which the message should be encoded - */ - virtual void encode(MessageMetadata& metadata, Protocol& proto, - Buffer::Instance& buffer) const PURE; -}; - -typedef std::unique_ptr DirectResponsePtr; - } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/protocol_converter.h b/source/extensions/filters/network/thrift_proxy/protocol_converter.h index 8a3fde7894505..22a868f1a40af 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_converter.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_converter.h @@ -19,8 +19,8 @@ class ProtocolConverter : public virtual DecoderEventHandler { ProtocolConverter() {} virtual ~ProtocolConverter() {} - void initProtocolConverter(ProtocolPtr&& proto, Buffer::Instance& buffer) { - proto_ = std::move(proto); + void initProtocolConverter(Protocol& proto, Buffer::Instance& buffer) { + proto_ = &proto; buffer_ = &buffer; } @@ -125,7 +125,7 @@ class ProtocolConverter : public virtual DecoderEventHandler { ProtocolType protocolType() const { return proto_->type(); } private: - ProtocolPtr proto_; + Protocol* proto_; Buffer::Instance* buffer_{}; }; diff --git a/source/extensions/filters/network/thrift_proxy/router/BUILD b/source/extensions/filters/network/thrift_proxy/router/BUILD index e5051335b04e9..ba6cf07cc045b 100644 --- a/source/extensions/filters/network/thrift_proxy/router/BUILD +++ b/source/extensions/filters/network/thrift_proxy/router/BUILD @@ -26,7 +26,9 @@ envoy_cc_library( name = "router_interface", hdrs = ["router.h"], external_deps = ["abseil_optional"], - deps = [], + deps = [ + "//source/extensions/filters/network/thrift_proxy:metadata_lib", + ], ) envoy_cc_library( @@ -46,8 +48,9 @@ envoy_cc_library( "//source/extensions/filters/network/thrift_proxy:app_exception_lib", "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", "//source/extensions/filters/network/thrift_proxy:protocol_converter_lib", - "//source/extensions/filters/network/thrift_proxy:protocol_lib", - "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_interface", + "//source/extensions/filters/network/thrift_proxy:thrift_object_interface", + "//source/extensions/filters/network/thrift_proxy:transport_interface", "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", "@envoy_api//envoy/config/filter/network/thrift_proxy/v2alpha1:thrift_proxy_cc", ], diff --git a/source/extensions/filters/network/thrift_proxy/router/router.h b/source/extensions/filters/network/thrift_proxy/router/router.h index 7f995e5e25a1d..b456a24ae7da1 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router.h +++ b/source/extensions/filters/network/thrift_proxy/router/router.h @@ -3,6 +3,8 @@ #include #include +#include "extensions/filters/network/thrift_proxy/metadata.h" + namespace Envoy { namespace Extensions { namespace NetworkFilters { diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc index 57e8be9f6caf1..a854b7432aa87 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -213,13 +213,12 @@ FilterStatus Router::messageBegin(MessageMetadataSharedPtr metadata) { FilterStatus Router::messageEnd() { ProtocolConverter::messageEnd(); - TransportPtr transport = - NamedTransportConfigFactory::getFactory(upstream_request_->transport_type_).createTransport(); Buffer::OwnedImpl transport_buffer; - upstream_request_->metadata_->setProtocol(upstream_request_->protocol_type_); + upstream_request_->metadata_->setProtocol(upstream_request_->protocol_->type()); - transport->encodeFrame(transport_buffer, *upstream_request_->metadata_, upstream_request_buffer_); + upstream_request_->transport_->encodeFrame(transport_buffer, *upstream_request_->metadata_, + upstream_request_buffer_); upstream_request_->conn_data_->connection().write(transport_buffer, false); upstream_request_->onRequestComplete(); return FilterStatus::Continue; @@ -228,16 +227,32 @@ FilterStatus Router::messageEnd() { void Router::onUpstreamData(Buffer::Instance& data, bool end_stream) { ASSERT(!upstream_request_->response_complete_); - if (!upstream_request_->response_started_) { - callbacks_->startUpstreamResponse(upstream_request_->transport_type_, - upstream_request_->protocol_type_); - upstream_request_->response_started_ = true; - } + if (upstream_request_->upgrade_response_ != nullptr) { + // Handle upgrade response. + if (!upstream_request_->upgrade_response_->onData(data)) { + // Wait for more data. + return; + } - if (callbacks_->upstreamData(data)) { - upstream_request_->onResponseComplete(); - cleanup(); - return; + upstream_request_->protocol_->completeUpgrade( + *upstream_request_->conn_data_->connectionStateTyped(), + *upstream_request_->upgrade_response_); + + upstream_request_->upgrade_response_.reset(); + upstream_request_->onRequestStart(true); + } else { + // Handle normal response. + if (!upstream_request_->response_started_) { + callbacks_->startUpstreamResponse(*upstream_request_->transport_, + *upstream_request_->protocol_); + upstream_request_->response_started_ = true; + } + + if (callbacks_->upstreamData(data)) { + upstream_request_->onResponseComplete(); + cleanup(); + return; + } } if (end_stream) { @@ -284,9 +299,10 @@ void Router::cleanup() { upstream_request_.reset(); } Router::UpstreamRequest::UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, MessageMetadataSharedPtr& metadata, TransportType transport_type, ProtocolType protocol_type) - : parent_(parent), conn_pool_(pool), metadata_(metadata), transport_type_(transport_type), - protocol_type_(protocol_type), request_complete_(false), response_started_(false), - response_complete_(false) {} + : parent_(parent), conn_pool_(pool), metadata_(metadata), + transport_(NamedTransportConfigFactory::getFactory(transport_type).createTransport()), + protocol_(NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol()), + request_complete_(false), response_started_(false), response_complete_(false) {} Router::UpstreamRequest::~UpstreamRequest() {} @@ -298,6 +314,11 @@ FilterStatus Router::UpstreamRequest::start() { return FilterStatus::StopIteration; } + if (upgrade_response_ != nullptr) { + // Pause while we wait for an upgrade response. + return FilterStatus::StopIteration; + } + return FilterStatus::Continue; } @@ -329,12 +350,28 @@ void Router::UpstreamRequest::onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr onUpstreamHostSelected(host); conn_data_ = std::move(conn_data); conn_data_->addUpstreamCallbacks(parent_); - conn_pool_handle_ = nullptr; - parent_.initProtocolConverter( - NamedProtocolConfigFactory::getFactory(protocol_type_).createProtocol(), - parent_.upstream_request_buffer_); + ThriftConnectionState* state = conn_data_->connectionStateTyped(); + if (state == nullptr) { + conn_data_->setConnectionState(std::make_unique()); + state = conn_data_->connectionStateTyped(); + } + + if (protocol_->supportsUpgrade()) { + upgrade_response_ = + protocol_->attemptUpgrade(*transport_, *state, parent_.upstream_request_buffer_); + if (upgrade_response_ != nullptr) { + conn_data_->connection().write(parent_.upstream_request_buffer_, false); + return; + } + } + + onRequestStart(continue_decoding); +} + +void Router::UpstreamRequest::onRequestStart(bool continue_decoding) { + parent_.initProtocolConverter(*protocol_, parent_.upstream_request_buffer_); // TODO(zuercher): need to use an upstream-connection-specific sequence id parent_.convertMessageBegin(metadata_); diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.h b/source/extensions/filters/network/thrift_proxy/router/router_impl.h index 224b6fb09997c..8d6033bdb410a 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.h +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.h @@ -15,6 +15,7 @@ #include "extensions/filters/network/thrift_proxy/conn_manager.h" #include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/router/router.h" +#include "extensions/filters/network/thrift_proxy/thrift_object.h" #include "absl/types/optional.h" @@ -135,6 +136,7 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, void onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn, Upstream::HostDescriptionConstSharedPtr host) override; + void onRequestStart(bool continue_decoding); void onRequestComplete(); void onResponseComplete(); void onUpstreamHostSelected(Upstream::HostDescriptionConstSharedPtr host); @@ -147,8 +149,9 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, Tcp::ConnectionPool::Cancellable* conn_pool_handle_{}; Tcp::ConnectionPool::ConnectionDataPtr conn_data_; Upstream::HostDescriptionConstSharedPtr upstream_host_; - TransportType transport_type_; - ProtocolType protocol_type_; + TransportPtr transport_; + ProtocolPtr protocol_; + ThriftObjectPtr upgrade_response_; bool request_complete_ : 1; bool response_started_ : 1; diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object.h b/source/extensions/filters/network/thrift_proxy/thrift_object.h new file mode 100644 index 0000000000000..321e6496674cd --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/thrift_object.h @@ -0,0 +1,247 @@ +#pragma once + +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/common/exception.h" + +#include "extensions/filters/network/thrift_proxy/thrift.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class ThriftBase; + +/** + * ThriftValue is a field or container (list, set, or map) element. + */ +class ThriftValue { +public: + virtual ~ThriftValue() {} + + /** + * @return FieldType the type of this value + */ + virtual FieldType type() const PURE; + + /** + * @return const T& pointer to the value, provided that it can be cast to the given type + * @throw EnvoyException if the type T does not match the type + */ + template const T& getValueTyped() const { + // Use the Traits template to determine what FieldType the value must have to be cast to T + // and throw if the value's type doesn't match. + FieldType expected_field_type = Traits::getFieldType(); + if (expected_field_type != type()) { + throw EnvoyException(fmt::format("expected field type {}, got {}", + static_cast(expected_field_type), + static_cast(type()))); + } + + return *static_cast(getValue()); + } + +protected: + /** + * @return void* pointing to the underlying value, to be dynamically cast in getValueTyped + */ + virtual const void* getValue() const PURE; + +private: + /** + * Traits allows getValueTyped() to enforce that the field type is being cast to the desired type. + */ + template class Traits { + public: + // Compilation failures where T does not have a member getFieldType typically mean that + // getValueTyped was called with a type T that is not used to encode Thrift values. + // The specializations below encode the valid types for Thrift primitive types. + static FieldType getFieldType() { return T::getFieldType(); } + }; +}; + +// Explicit specializations of ThriftValue::Types for primitive types. +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::Bool; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::Byte; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::I16; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::I32; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::I64; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::Double; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::String; } +}; + +typedef std::unique_ptr ThriftValuePtr; +typedef std::list ThriftValuePtrList; +typedef std::list> ThriftValuePtrPairList; + +/** + * ThriftField is a field within a ThriftStruct. + */ +class ThriftField { +public: + virtual ~ThriftField() {} + + /** + * @return FieldType this field's type + */ + virtual FieldType fieldType() const PURE; + + /** + * @return int16_t the field's identifier + */ + virtual int16_t fieldId() const PURE; + + /** + * @return const ThriftValue& containing the field's value + */ + virtual const ThriftValue& getValue() const PURE; +}; + +typedef std::unique_ptr ThriftFieldPtr; +typedef std::list ThriftFieldPtrList; + +/** + * ThriftListValue is an ordered list of ThriftValues. + */ +class ThriftListValue { +public: + virtual ~ThriftListValue() {} + + /** + * @return const ThriftValuePtrList& containing the ThriftValues that comprise the list + */ + virtual const ThriftValuePtrList& elements() const PURE; + + /** + * @return FieldType of the underlying elements + */ + virtual FieldType elementType() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::List; } +}; + +/** + * ThriftSetValue is a set of ThriftValues, maintained in their original order. + */ +class ThriftSetValue { +public: + virtual ~ThriftSetValue() {} + + /** + * @return const ThriftValuePtrList& containing the ThriftValues that comprise the set + */ + virtual const ThriftValuePtrList& elements() const PURE; + + /** + * @return FieldType of the underlying elements + */ + virtual FieldType elementType() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::Set; } +}; + +/** + * ThriftMapValue is a map of pairs of ThriftValues, maintained in their original order. + */ +class ThriftMapValue { +public: + virtual ~ThriftMapValue() {} + + /** + * @return const ThriftValuePtrPairList& containing the ThriftValue key-value paris that comprise + * the map. + */ + virtual const ThriftValuePtrPairList& elements() const PURE; + + /** + * @return FieldType of the underlying keys + */ + virtual FieldType keyType() const PURE; + + /** + * @return FieldType of the underlying values + */ + virtual FieldType valueType() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::Map; } +}; + +/** + * ThriftStructValue is a sequence of ThriftFields. + */ +class ThriftStructValue { +public: + virtual ~ThriftStructValue() {} + + /** + * @return const ThriftFieldPtrList& containing the ThriftFields that comprise the struct. + */ + virtual const ThriftFieldPtrList& fields() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::Struct; } +}; + +/** + * ThriftObject is a ThrfitStructValue that can be read from a Buffer::Instance. + */ +class ThriftObject : public ThriftStructValue { +public: + virtual ~ThriftObject() {} + + /* + * Consumes bytes from the buffer until a single complete Thrift struct has been consumed. + * @param buffer starting with a Thrift struct + * @return true when a single complete struct has been consumed; false if more data is needed to + * complete decoding + * @throw EnvoyException if the struct is invalid + */ + virtual bool onData(Buffer::Instance& buffer) PURE; +}; + +typedef std::unique_ptr ThriftObjectPtr; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object_impl.cc b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.cc new file mode 100644 index 0000000000000..fb9271331f3a2 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.cc @@ -0,0 +1,394 @@ +#include "extensions/filters/network/thrift_proxy/thrift_object_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace { + +std::unique_ptr makeValue(ThriftBase* parent, FieldType type) { + switch (type) { + case FieldType::Stop: + NOT_REACHED_GCOVR_EXCL_LINE; + + case FieldType::List: + return std::make_unique(parent); + + case FieldType::Set: + return std::make_unique(parent); + + case FieldType::Map: + return std::make_unique(parent); + + case FieldType::Struct: + return std::make_unique(parent); + + default: + return std::make_unique(parent, type); + } +} + +} // namespace + +ThriftBase::ThriftBase(ThriftBase* parent) : parent_(parent) {} + +FilterStatus ThriftBase::structBegin(absl::string_view name) { + ASSERT(delegate_ != nullptr); + return delegate_->structBegin(name); +} + +FilterStatus ThriftBase::structEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->structEnd(); +} + +FilterStatus ThriftBase::fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) { + ASSERT(delegate_ != nullptr); + return delegate_->fieldBegin(name, field_type, field_id); +} + +FilterStatus ThriftBase::fieldEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->fieldEnd(); +} + +FilterStatus ThriftBase::boolValue(bool value) { + ASSERT(delegate_ != nullptr); + return delegate_->boolValue(value); +} + +FilterStatus ThriftBase::byteValue(uint8_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->byteValue(value); +} + +FilterStatus ThriftBase::int16Value(int16_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->int16Value(value); +} + +FilterStatus ThriftBase::int32Value(int32_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->int32Value(value); +} + +FilterStatus ThriftBase::int64Value(int64_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->int64Value(value); +} + +FilterStatus ThriftBase::doubleValue(double value) { + ASSERT(delegate_ != nullptr); + return delegate_->doubleValue(value); +} + +FilterStatus ThriftBase::stringValue(absl::string_view value) { + ASSERT(delegate_ != nullptr); + return delegate_->stringValue(value); +} + +FilterStatus ThriftBase::mapBegin(FieldType key_type, FieldType value_type, uint32_t size) { + ASSERT(delegate_ != nullptr); + return delegate_->mapBegin(key_type, value_type, size); +} + +FilterStatus ThriftBase::mapEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->mapEnd(); +} + +FilterStatus ThriftBase::listBegin(FieldType elem_type, uint32_t size) { + ASSERT(delegate_ != nullptr); + return delegate_->listBegin(elem_type, size); +} + +FilterStatus ThriftBase::listEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->listEnd(); +} + +FilterStatus ThriftBase::setBegin(FieldType elem_type, uint32_t size) { + ASSERT(delegate_ != nullptr); + return delegate_->setBegin(elem_type, size); +} + +FilterStatus ThriftBase::setEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->setEnd(); +} + +void ThriftBase::delegateComplete() { + ASSERT(delegate_ != nullptr); + delegate_ = nullptr; +} + +ThriftFieldImpl::ThriftFieldImpl(ThriftStructValueImpl* parent, absl::string_view name, + FieldType field_type, int16_t field_id) + : ThriftBase(parent), name_(name), field_type_(field_type), field_id_(field_id) { + auto value = makeValue(this, field_type_); + delegate_ = value.get(); + value_ = std::move(value); +} + +FilterStatus ThriftFieldImpl::fieldEnd() { + if (delegate_) { + return delegate_->fieldEnd(); + } + + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftListValueImpl::listBegin(FieldType elem_type, uint32_t size) { + if (delegate_) { + return delegate_->listBegin(elem_type, size); + } + + elem_type_ = elem_type; + remaining_ = size; + + delegateComplete(); + + return FilterStatus::Continue; +} + +FilterStatus ThriftListValueImpl::listEnd() { + if (delegate_) { + return delegate_->listEnd(); + } + + ASSERT(remaining_ == 0); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +void ThriftListValueImpl::delegateComplete() { + delegate_ = nullptr; + + if (remaining_ == 0) { + return; + } + + auto elem = makeValue(this, elem_type_); + delegate_ = elem.get(); + elements_.push_back(std::move(elem)); + remaining_--; +} + +FilterStatus ThriftSetValueImpl::setBegin(FieldType elem_type, uint32_t size) { + if (delegate_) { + return delegate_->setBegin(elem_type, size); + } + + elem_type_ = elem_type; + remaining_ = size; + + delegateComplete(); + + return FilterStatus::Continue; +} + +FilterStatus ThriftSetValueImpl::setEnd() { + if (delegate_) { + return delegate_->setEnd(); + } + + ASSERT(remaining_ == 0); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +void ThriftSetValueImpl::delegateComplete() { + delegate_ = nullptr; + + if (remaining_ == 0) { + return; + } + + auto elem = makeValue(this, elem_type_); + delegate_ = elem.get(); + elements_.push_back(std::move(elem)); + remaining_--; +} + +FilterStatus ThriftMapValueImpl::mapBegin(FieldType key_type, FieldType elem_type, uint32_t size) { + if (delegate_) { + return delegate_->mapBegin(key_type, elem_type, size); + } + + key_type_ = key_type; + elem_type_ = elem_type; + remaining_ = size; + + delegateComplete(); + + return FilterStatus::Continue; +} + +FilterStatus ThriftMapValueImpl::mapEnd() { + if (delegate_) { + return delegate_->mapEnd(); + } + + ASSERT(remaining_ == 0); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +void ThriftMapValueImpl::delegateComplete() { + delegate_ = nullptr; + + if (remaining_ == 0) { + return; + } + + // Prepare for first element's key. + if (elements_.empty()) { + auto key = makeValue(this, key_type_); + delegate_ = key.get(); + elements_.emplace_back(std::move(key), nullptr); + return; + } + + // Prepare for any elements's value. + auto& elem = elements_.back(); + if (elem.second == nullptr) { + auto value = makeValue(this, elem_type_); + delegate_ = value.get(); + elem.second = std::move(value); + + remaining_--; + return; + } + + // Key-value pair completed, prepare for next key. + auto key = makeValue(this, key_type_); + delegate_ = key.get(); + elements_.emplace_back(std::move(key), nullptr); +} + +FilterStatus ThriftValueImpl::boolValue(bool value) { + ASSERT(value_type_ == FieldType::Bool); + bool_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::byteValue(uint8_t value) { + ASSERT(value_type_ == FieldType::Byte); + byte_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::int16Value(int16_t value) { + ASSERT(value_type_ == FieldType::I16); + int16_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::int32Value(int32_t value) { + ASSERT(value_type_ == FieldType::I32); + int32_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::int64Value(int64_t value) { + ASSERT(value_type_ == FieldType::I64); + int64_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::doubleValue(double value) { + ASSERT(value_type_ == FieldType::Double); + double_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::stringValue(absl::string_view value) { + ASSERT(value_type_ == FieldType::String); + string_value_ = std::string(value); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +const void* ThriftValueImpl::getValue() const { + switch (value_type_) { + case FieldType::Bool: + return &bool_value_; + case FieldType::Byte: + return &byte_value_; + case FieldType::I16: + return &int16_value_; + case FieldType::I32: + return &int32_value_; + case FieldType::I64: + return &int64_value_; + case FieldType::Double: + return &double_value_; + case FieldType::String: + return &string_value_; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } +} + +FilterStatus ThriftStructValueImpl::structBegin(absl::string_view name) { + if (delegate_) { + return delegate_->structBegin(name); + } + + return FilterStatus::Continue; +} + +FilterStatus ThriftStructValueImpl::structEnd() { + if (delegate_) { + return delegate_->structEnd(); + } + + if (parent_) { + parent_->delegateComplete(); + } + + return FilterStatus::Continue; +} + +FilterStatus ThriftStructValueImpl::fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) { + if (delegate_) { + return delegate_->fieldBegin(name, field_type, field_id); + } + + if (field_type != FieldType::Stop) { + auto field = std::make_unique(this, name, field_type, field_id); + delegate_ = field.get(); + fields_.emplace_back(std::move(field)); + } + + return FilterStatus::Continue; +} + +ThriftObjectImpl::ThriftObjectImpl(Transport& transport, Protocol& protocol) + : ThriftStructValueImpl(nullptr), + decoder_(std::make_unique(transport, protocol, *this)) {} + +bool ThriftObjectImpl::onData(Buffer::Instance& buffer) { + bool underflow = false; + auto result = decoder_->onData(buffer, underflow); + ASSERT(result == FilterStatus::Continue); + + if (complete_) { + decoder_.reset(); + } + return complete_; +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h new file mode 100644 index 0000000000000..b9057dfab2bc7 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h @@ -0,0 +1,262 @@ +#pragma once + +#include "extensions/filters/network/thrift_proxy/decoder.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/thrift_object.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * ThriftBase is a base class for decoding Thrift objects. It implements methods from + * DecoderEventHandler to automatically delegate to an underlying ThriftBase so that, for example, + * the fieldBegin call for a struct field nested within a list is automatically delegated down the + * object hierarchy to the correct ThriftBase subclass. + */ +class ThriftBase : public DecoderEventHandler { +public: + ThriftBase(ThriftBase* parent); + ~ThriftBase() {} + + // DecoderEventHandler + FilterStatus transportBegin(MessageMetadataSharedPtr) override { return FilterStatus::Continue; } + FilterStatus transportEnd() override { return FilterStatus::Continue; } + FilterStatus messageBegin(MessageMetadataSharedPtr) override { return FilterStatus::Continue; } + FilterStatus messageEnd() override { return FilterStatus::Continue; } + FilterStatus structBegin(absl::string_view name) override; + FilterStatus structEnd() override; + FilterStatus fieldBegin(absl::string_view name, FieldType field_type, int16_t field_id) override; + FilterStatus fieldEnd() override; + FilterStatus boolValue(bool value) override; + FilterStatus byteValue(uint8_t value) override; + FilterStatus int16Value(int16_t value) override; + FilterStatus int32Value(int32_t value) override; + FilterStatus int64Value(int64_t value) override; + FilterStatus doubleValue(double value) override; + FilterStatus stringValue(absl::string_view value) override; + FilterStatus mapBegin(FieldType key_type, FieldType value_type, uint32_t size) override; + FilterStatus mapEnd() override; + FilterStatus listBegin(FieldType elem_type, uint32_t size) override; + FilterStatus listEnd() override; + FilterStatus setBegin(FieldType elem_type, uint32_t size) override; + FilterStatus setEnd() override; + + // Invoked when the current delegate is complete. Completion implies that the delegate is fully + // specified (all list values processed, all struct fields processed, etc). + virtual void delegateComplete(); + +protected: + ThriftBase* parent_; + ThriftBase* delegate_{nullptr}; +}; + +/** + * ThriftValueBase is a base class for all struct field values, list values, set values, map keys, + * and map values. + */ +class ThriftValueBase : public ThriftValue, public ThriftBase { +public: + ThriftValueBase(ThriftBase* parent, FieldType value_type) + : ThriftBase(parent), value_type_(value_type) {} + ~ThriftValueBase() {} + + // ThriftValue + FieldType type() const override { return value_type_; } + +protected: + const FieldType value_type_; +}; + +class ThriftStructValueImpl; + +/** + * ThriftField represents a field in a thrift Struct. It always delegates DecoderEventHandler + * methods to a subclass of ThriftValueBase. + */ +class ThriftFieldImpl : public ThriftField, public ThriftBase { +public: + ThriftFieldImpl(ThriftStructValueImpl* parent, absl::string_view name, FieldType field_type, + int16_t field_id); + + // DecoderEventHandler + FilterStatus fieldEnd() override; + + // ThriftField + FieldType fieldType() const override { return field_type_; } + int16_t fieldId() const override { return field_id_; } + const ThriftValue& getValue() const override { return *value_; } + +private: + std::string name_; + FieldType field_type_; + int16_t field_id_; + ThriftValuePtr value_; +}; + +/** + * ThriftStructValueImpl implements ThriftStruct. + */ +class ThriftStructValueImpl : public ThriftStructValue, public ThriftValueBase { +public: + ThriftStructValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::Struct) {} + + // DecoderEventHandler + FilterStatus structBegin(absl::string_view name) override; + FilterStatus structEnd() override; + FilterStatus fieldBegin(absl::string_view name, FieldType field_type, int16_t field_id) override; + + // ThriftStructValue + const ThriftFieldPtrList& fields() const override { return fields_; } + +private: + // ThriftValue + const void* getValue() const override { return this; }; + + ThriftFieldPtrList fields_; +}; + +/** + * ThriftListValueImpl represents Thrift lists. + */ +class ThriftListValueImpl : public ThriftListValue, public ThriftValueBase { +public: + ThriftListValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::List) {} + + // DecoderEventHandler + FilterStatus listBegin(FieldType elem_type, uint32_t size) override; + FilterStatus listEnd() override; + + // ThriftListValue + const ThriftValuePtrList& elements() const override { return elements_; } + FieldType elementType() const override { return elem_type_; } + + void delegateComplete() override; + +protected: + // ThriftValue + const void* getValue() const override { return this; }; + + FieldType elem_type_{FieldType::Stop}; + uint32_t remaining_{0}; + ThriftValuePtrList elements_; +}; + +/** + * ThriftSetValueImpl represents Thrift sets. + */ +class ThriftSetValueImpl : public ThriftSetValue, public ThriftValueBase { +public: + ThriftSetValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::Set) {} + + // DecoderEventHandler + FilterStatus setBegin(FieldType elem_type, uint32_t size) override; + FilterStatus setEnd() override; + + // ThriftSetValue + const ThriftValuePtrList& elements() const override { return elements_; } + FieldType elementType() const override { return elem_type_; } + + void delegateComplete() override; + +protected: + // ThriftValue + const void* getValue() const override { return this; }; + + FieldType elem_type_{FieldType::Stop}; + uint32_t remaining_{0}; + ThriftValuePtrList elements_; // maintain original order +}; + +/** + * ThriftMapValueImpl represents Thrift maps. + */ +class ThriftMapValueImpl : public ThriftMapValue, public ThriftValueBase { +public: + ThriftMapValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::Map) {} + + // DecoderEventHandler + FilterStatus mapBegin(FieldType key_type, FieldType elem_type, uint32_t size) override; + FilterStatus mapEnd() override; + + // ThriftMapValue + const ThriftValuePtrPairList& elements() const override { return elements_; } + FieldType keyType() const override { return key_type_; } + FieldType valueType() const override { return elem_type_; } + + void delegateComplete() override; + +protected: + // ThriftValue + const void* getValue() const override { return this; }; + + FieldType key_type_{FieldType::Stop}; + FieldType elem_type_{FieldType::Stop}; + uint32_t remaining_{0}; + ThriftValuePtrPairList elements_; // maintain original order +}; + +/** + * ThriftValueImpl represents primitive Thrift types, including strings. + */ +class ThriftValueImpl : public ThriftValueBase { +public: + ThriftValueImpl(ThriftBase* parent, FieldType value_type) : ThriftValueBase(parent, value_type) {} + + // DecoderEventHandler + FilterStatus boolValue(bool value) override; + FilterStatus byteValue(uint8_t value) override; + FilterStatus int16Value(int16_t value) override; + FilterStatus int32Value(int32_t value) override; + FilterStatus int64Value(int64_t value) override; + FilterStatus doubleValue(double value) override; + FilterStatus stringValue(absl::string_view value) override; + +protected: + // ThriftValue + const void* getValue() const override; + +private: + union { + bool bool_value_; + uint8_t byte_value_; + int16_t int16_value_; + int32_t int32_value_; + int64_t int64_value_; + double double_value_; + }; + std::string string_value_; +}; + +/** + * ThriftObjectImpl is a generic representation of a Thrift struct. + */ +class ThriftObjectImpl : public ThriftObject, + public ThriftStructValueImpl, + public DecoderCallbacks { +public: + ThriftObjectImpl(Transport& transport, Protocol& protocol); + + // DecoderCallbacks + DecoderEventHandler& newDecoderEventHandler() override { return *this; } + FilterStatus transportEnd() override { + complete_ = true; + return FilterStatus::Continue; + } + + // ThriftObject + bool onData(Buffer::Instance& buffer) override; + + // ThriftStruct + const ThriftFieldPtrList& fields() const override { return ThriftStructValueImpl::fields(); } + +private: + DecoderPtr decoder_; + bool complete_{false}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/stat_sinks/metrics_service/config.cc b/source/extensions/stat_sinks/metrics_service/config.cc index 234b46b377406..22e9e58fad1ff 100644 --- a/source/extensions/stat_sinks/metrics_service/config.cc +++ b/source/extensions/stat_sinks/metrics_service/config.cc @@ -29,7 +29,7 @@ Stats::SinkPtr MetricsServiceSinkFactory::createStatsSink(const Protobuf::Messag grpc_service, server.stats(), false), server.threadLocal(), server.localInfo()); - return std::make_unique(grpc_metrics_streamer); + return std::make_unique(grpc_metrics_streamer, server.timeSource()); } ProtobufTypes::MessagePtr MetricsServiceSinkFactory::createEmptyConfigProto() { diff --git a/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.cc b/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.cc index ce68afd5f98aa..b218a57f1c74e 100644 --- a/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.cc +++ b/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.cc @@ -61,15 +61,18 @@ void GrpcMetricsStreamerImpl::ThreadLocalStreamer::send( } } -MetricsServiceSink::MetricsServiceSink(const GrpcMetricsStreamerSharedPtr& grpc_metrics_streamer) - : grpc_metrics_streamer_(grpc_metrics_streamer) {} +MetricsServiceSink::MetricsServiceSink(const GrpcMetricsStreamerSharedPtr& grpc_metrics_streamer, + TimeSource& time_source) + : grpc_metrics_streamer_(grpc_metrics_streamer), time_source_(time_source) {} void MetricsServiceSink::flushCounter(const Stats::Counter& counter) { io::prometheus::client::MetricFamily* metrics_family = message_.add_envoy_metrics(); metrics_family->set_type(io::prometheus::client::MetricType::COUNTER); metrics_family->set_name(counter.name()); auto* metric = metrics_family->add_metric(); - metric->set_timestamp_ms(std::chrono::system_clock::now().time_since_epoch().count()); + metric->set_timestamp_ms(std::chrono::duration_cast( + time_source_.systemTime().time_since_epoch()) + .count()); auto* counter_metric = metric->mutable_counter(); counter_metric->set_value(counter.value()); } @@ -79,7 +82,9 @@ void MetricsServiceSink::flushGauge(const Stats::Gauge& gauge) { metrics_family->set_type(io::prometheus::client::MetricType::GAUGE); metrics_family->set_name(gauge.name()); auto* metric = metrics_family->add_metric(); - metric->set_timestamp_ms(std::chrono::system_clock::now().time_since_epoch().count()); + metric->set_timestamp_ms(std::chrono::duration_cast( + time_source_.systemTime().time_since_epoch()) + .count()); auto* gauage_metric = metric->mutable_gauge(); gauage_metric->set_value(gauge.value()); } @@ -88,7 +93,9 @@ void MetricsServiceSink::flushHistogram(const Stats::ParentHistogram& histogram) metrics_family->set_type(io::prometheus::client::MetricType::SUMMARY); metrics_family->set_name(histogram.name()); auto* metric = metrics_family->add_metric(); - metric->set_timestamp_ms(std::chrono::system_clock::now().time_since_epoch().count()); + metric->set_timestamp_ms(std::chrono::duration_cast( + time_source_.systemTime().time_since_epoch()) + .count()); auto* summary_metric = metric->mutable_summary(); const Stats::HistogramStatistics& hist_stats = histogram.intervalStatistics(); for (size_t i = 0; i < hist_stats.supportedQuantiles().size(); i++) { diff --git a/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.h b/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.h index b27792652d788..bb796052a311b 100644 --- a/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.h +++ b/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.h @@ -111,7 +111,8 @@ class GrpcMetricsStreamerImpl : public Singleton::Instance, public GrpcMetricsSt class MetricsServiceSink : public Stats::Sink { public: // MetricsService::Sink - MetricsServiceSink(const GrpcMetricsStreamerSharedPtr& grpc_metrics_streamer); + MetricsServiceSink(const GrpcMetricsStreamerSharedPtr& grpc_metrics_streamer, + TimeSource& time_source); void flush(Stats::Source& source) override; void onHistogramComplete(const Stats::Histogram&, uint64_t) override {} @@ -122,6 +123,7 @@ class MetricsServiceSink : public Stats::Sink { private: GrpcMetricsStreamerSharedPtr grpc_metrics_streamer_; envoy::service::metrics::v2::StreamMetricsMessage message_; + TimeSource& time_source_; }; } // namespace MetricsService diff --git a/test/common/http/conn_manager_impl_test.cc b/test/common/http/conn_manager_impl_test.cc index 90e6e21f45a30..266a705a5fb97 100644 --- a/test/common/http/conn_manager_impl_test.cc +++ b/test/common/http/conn_manager_impl_test.cc @@ -1705,6 +1705,10 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketPrefixAndAutoHostRewrite) { Buffer::OwnedImpl fake_input("1234"); conn_manager_->onData(fake_input, false); + Tcp::ConnectionPool::UpstreamCallbacks* upstream_callbacks = nullptr; + EXPECT_CALL(*conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce( + Invoke([&](Tcp::ConnectionPool::UpstreamCallbacks& cb) { upstream_callbacks = &cb; })); conn_pool_.host_->hostname_ = "newhost"; conn_pool_.poolReady(upstream_conn_); @@ -1714,6 +1718,7 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketPrefixAndAutoHostRewrite) { EXPECT_EQ(1U, stats_.named_.downstream_cx_websocket_total_.value()); EXPECT_EQ(0U, stats_.named_.downstream_cx_http1_active_.value()); + upstream_callbacks->onEvent(Network::ConnectionEvent::RemoteClose); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); conn_manager_.reset(); EXPECT_EQ(0U, stats_.named_.downstream_cx_websocket_active_.value()); @@ -1753,8 +1758,13 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyData) { EXPECT_CALL(upstream_conn_, write(_, false)); EXPECT_CALL(upstream_conn_, write(BufferEqual(&early_data), false)); EXPECT_CALL(filter_callbacks_.connection_, readDisable(false)); + Tcp::ConnectionPool::UpstreamCallbacks* upstream_callbacks = nullptr; + EXPECT_CALL(*conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce( + Invoke([&](Tcp::ConnectionPool::UpstreamCallbacks& cb) { upstream_callbacks = &cb; })); conn_pool_.poolReady(upstream_conn_); + upstream_callbacks->onEvent(Network::ConnectionEvent::RemoteClose); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); conn_manager_.reset(); } @@ -1828,7 +1838,12 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyEndStream) { EXPECT_CALL(upstream_conn_, write(_, false)); EXPECT_CALL(upstream_conn_, write(_, true)).Times(0); + Tcp::ConnectionPool::UpstreamCallbacks* upstream_callbacks = nullptr; + EXPECT_CALL(*conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce( + Invoke([&](Tcp::ConnectionPool::UpstreamCallbacks& cb) { upstream_callbacks = &cb; })); conn_pool_.poolReady(upstream_conn_); + upstream_callbacks->onEvent(Network::ConnectionEvent::RemoteClose); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); conn_manager_.reset(); } diff --git a/test/common/network/filter_manager_impl_test.cc b/test/common/network/filter_manager_impl_test.cc index 06837dfbd450c..8aeaf11a978b9 100644 --- a/test/common/network/filter_manager_impl_test.cc +++ b/test/common/network/filter_manager_impl_test.cc @@ -214,6 +214,8 @@ TEST_F(NetworkFilterManagerTest, RateLimitAndTcpProxy) { EXPECT_CALL(upstream_connection, write(BufferEqual(&buffer), _)); read_buffer_.add("hello"); manager.onRead(); + + connection.raiseEvent(ConnectionEvent::RemoteClose); } } // namespace Network diff --git a/test/common/router/config_impl_test.cc b/test/common/router/config_impl_test.cc index 9951064466b8a..b063d6408bf79 100644 --- a/test/common/router/config_impl_test.cc +++ b/test/common/router/config_impl_test.cc @@ -535,6 +535,31 @@ TEST(RouteMatcherTest, TestRoutes) { } } +TEST(RouteMatcherTest, TestRoutesWithWildcardAndDefaultOnly) { + std::string yaml = R"EOF( +virtual_hosts: + - name: wildcard + domains: ["*.solo.io"] + routes: + - match: { prefix: "/" } + route: { cluster: "wildcard" } + - name: default + domains: ["*"] + routes: + - match: { prefix: "/" } + route: { cluster: "default" } + )EOF"; + + const auto proto_config = parseRouteConfigurationFromV2Yaml(yaml); + NiceMock factory_context; + TestConfigImpl config(proto_config, factory_context, true); + + EXPECT_EQ("wildcard", + config.route(genHeaders("gloo.solo.io", "/", "GET"), 0)->routeEntry()->clusterName()); + EXPECT_EQ("default", + config.route(genHeaders("example.com", "/", "GET"), 0)->routeEntry()->clusterName()); +} + TEST(RouteMatcherTest, TestRoutesWithInvalidRegex) { std::string invalid_route = R"EOF( virtual_hosts: @@ -1699,6 +1724,65 @@ TEST_F(RouterMatcherHashPolicyTest, HashMultiple) { EXPECT_NE(hash_h, hash_both); } +TEST_F(RouterMatcherHashPolicyTest, HashTerminal) { + // Hash policy list: cookie, header [terminal=true], user_ip. + auto route = route_config_.mutable_virtual_hosts(0)->mutable_routes(0)->mutable_route(); + route->add_hash_policy()->mutable_cookie()->set_name("cookie_hash"); + auto* header_hash = route->add_hash_policy(); + header_hash->mutable_header()->set_header_name("foo_header"); + header_hash->set_terminal(true); + route->add_hash_policy()->mutable_connection_properties()->set_source_ip(true); + Network::Address::Ipv4Instance address1("4.3.2.1"); + Network::Address::Ipv4Instance address2("1.2.3.4"); + + uint64_t hash_1, hash_2; + // Test terminal works when there is hash computed, the rest of the policy + // list is ignored. + { + Http::TestHeaderMapImpl headers = genHeaders("www.lyft.com", "/foo", "GET"); + headers.addCopy("Cookie", "cookie_hash=foo;"); + headers.addCopy("foo_header", "bar"); + Router::RouteConstSharedPtr route = config().route(headers, 0); + hash_1 = route->routeEntry() + ->hashPolicy() + ->generateHash(&address1, headers, add_cookie_nop_) + .value(); + } + { + Http::TestHeaderMapImpl headers = genHeaders("www.lyft.com", "/foo", "GET"); + headers.addCopy("Cookie", "cookie_hash=foo;"); + headers.addCopy("foo_header", "bar"); + Router::RouteConstSharedPtr route = config().route(headers, 0); + hash_2 = route->routeEntry() + ->hashPolicy() + ->generateHash(&address2, headers, add_cookie_nop_) + .value(); + } + EXPECT_EQ(hash_1, hash_2); + + // If no hash computed after evaluating a hash policy, the rest of the policy + // list is evaluated. + { + // Input: {}, {}, address1. Hash on address1. + Http::TestHeaderMapImpl headers = genHeaders("www.lyft.com", "/foo", "GET"); + Router::RouteConstSharedPtr route = config().route(headers, 0); + hash_1 = route->routeEntry() + ->hashPolicy() + ->generateHash(&address1, headers, add_cookie_nop_) + .value(); + } + { + // Input: {}, {}, address2. Hash on address2. + Http::TestHeaderMapImpl headers = genHeaders("www.lyft.com", "/foo", "GET"); + Router::RouteConstSharedPtr route = config().route(headers, 0); + hash_2 = route->routeEntry() + ->hashPolicy() + ->generateHash(&address2, headers, add_cookie_nop_) + .value(); + } + EXPECT_NE(hash_1, hash_2); +} + TEST_F(RouterMatcherHashPolicyTest, InvalidHashPolicies) { NiceMock factory_context; { diff --git a/test/common/secret/sds_api_test.cc b/test/common/secret/sds_api_test.cc index 17435fe5c8d80..7a4432e4ff744 100644 --- a/test/common/secret/sds_api_test.cc +++ b/test/common/secret/sds_api_test.cc @@ -41,8 +41,9 @@ TEST_F(SdsApiTest, BasicTest) { auto google_grpc = grpc_service->mutable_google_grpc(); google_grpc->set_target_uri("fake_address"); google_grpc->set_stat_prefix("test"); - SdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), server.stats(), - server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + TlsCertificateSdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), + server.stats(), server.clusterManager(), init_manager, config_source, + "abc.com", []() {}); NiceMock* grpc_client{new NiceMock()}; NiceMock* factory{new NiceMock()}; @@ -57,13 +58,15 @@ TEST_F(SdsApiTest, BasicTest) { init_manager.initialize(); } -// Validate that SdsApi updates secrets successfully if a good secret is passed to onConfigUpdate(). -TEST_F(SdsApiTest, SecretUpdateSuccess) { +// Validate that TlsCertificateSdsApi updates secrets successfully if a good secret +// is passed to onConfigUpdate(). +TEST_F(SdsApiTest, DynamicTlsCertificateUpdateSuccess) { NiceMock server; NiceMock init_manager; envoy::api::v2::core::ConfigSource config_source; - SdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), server.stats(), - server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + TlsCertificateSdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), + server.stats(), server.clusterManager(), init_manager, config_source, + "abc.com", []() {}); NiceMock secret_callback; auto handle = @@ -96,13 +99,49 @@ TEST_F(SdsApiTest, SecretUpdateSuccess) { handle->remove(); } +// Validate that CertificateValidationContextSdsApi updates secrets successfully if +// a good secret is passed to onConfigUpdate(). +TEST_F(SdsApiTest, DynamicCertificateValidationContextUpdateSuccess) { + NiceMock server; + NiceMock init_manager; + envoy::api::v2::core::ConfigSource config_source; + CertificateValidationContextSdsApi sds_api( + server.localInfo(), server.dispatcher(), server.random(), server.stats(), + server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + + NiceMock secret_callback; + auto handle = + sds_api.addUpdateCallback([&secret_callback]() { secret_callback.onAddOrUpdateSecret(); }); + + std::string yaml = + R"EOF( + name: "abc.com" + validation_context: + trusted_ca: { filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" } + allow_expired_certificate: true + )EOF"; + + Protobuf::RepeatedPtrField secret_resources; + auto secret_config = secret_resources.Add(); + MessageUtil::loadFromYaml(TestEnvironment::substitute(yaml), *secret_config); + EXPECT_CALL(secret_callback, onAddOrUpdateSecret()); + sds_api.onConfigUpdate(secret_resources, ""); + + const std::string ca_cert = "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem"; + EXPECT_EQ(TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(ca_cert)), + sds_api.secret()->caCert()); + + handle->remove(); +} + // Validate that SdsApi throws exception if an empty secret is passed to onConfigUpdate(). TEST_F(SdsApiTest, EmptyResource) { NiceMock server; NiceMock init_manager; envoy::api::v2::core::ConfigSource config_source; - SdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), server.stats(), - server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + TlsCertificateSdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), + server.stats(), server.clusterManager(), init_manager, config_source, + "abc.com", []() {}); Protobuf::RepeatedPtrField secret_resources; @@ -115,8 +154,9 @@ TEST_F(SdsApiTest, SecretUpdateWrongSize) { NiceMock server; NiceMock init_manager; envoy::api::v2::core::ConfigSource config_source; - SdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), server.stats(), - server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + TlsCertificateSdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), + server.stats(), server.clusterManager(), init_manager, config_source, + "abc.com", []() {}); std::string yaml = R"EOF( @@ -144,8 +184,9 @@ TEST_F(SdsApiTest, SecretUpdateWrongSecretName) { NiceMock server; NiceMock init_manager; envoy::api::v2::core::ConfigSource config_source; - SdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), server.stats(), - server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + TlsCertificateSdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), + server.stats(), server.clusterManager(), init_manager, config_source, + "abc.com", []() {}); std::string yaml = R"EOF( diff --git a/test/common/secret/secret_manager_impl_test.cc b/test/common/secret/secret_manager_impl_test.cc index c340e63342f47..4432d9e94e364 100644 --- a/test/common/secret/secret_manager_impl_test.cc +++ b/test/common/secret/secret_manager_impl_test.cc @@ -171,7 +171,7 @@ name: "abc.com" Protobuf::RepeatedPtrField secret_resources; auto secret_config = secret_resources.Add(); MessageUtil::loadFromYaml(TestEnvironment::substitute(yaml), *secret_config); - dynamic_cast(*secret_provider).onConfigUpdate(secret_resources, ""); + dynamic_cast(*secret_provider).onConfigUpdate(secret_resources, ""); const std::string cert_pem = "{{ test_rundir }}/test/common/ssl/test_data/selfsigned_cert.pem"; EXPECT_EQ(TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(cert_pem)), secret_provider->secret()->certificateChain()); diff --git a/test/common/tcp_proxy/tcp_proxy_test.cc b/test/common/tcp_proxy/tcp_proxy_test.cc index f447f62b08e3e..1df01629ffa81 100644 --- a/test/common/tcp_proxy/tcp_proxy_test.cc +++ b/test/common/tcp_proxy/tcp_proxy_test.cc @@ -348,6 +348,12 @@ class TcpProxyTest : public testing::Test { .WillByDefault(SaveArg<0>(&access_log_data_)); } + ~TcpProxyTest() { + if (filter_ != nullptr) { + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); + } + } + void configure(const envoy::config::filter::network::tcp_proxy::v2::TcpProxy& config) { config_.reset(new Config(config, factory_context_)); } @@ -734,6 +740,22 @@ TEST_F(TcpProxyTest, DisconnectBeforeData) { filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); } +// Test that if the downstream connection is closed before the upstream connection +// is established, the upstream connection is cancelled. +TEST_F(TcpProxyTest, RemoteClosetBeforeUpstreamConnected) { + setup(1); + EXPECT_CALL(*conn_pool_handles_.at(0), cancel()); + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); +} + +// Test that if the downstream connection is closed before the upstream connection +// is established, the upstream connection is cancelled. +TEST_F(TcpProxyTest, LocalClosetBeforeUpstreamConnected) { + setup(1); + EXPECT_CALL(*conn_pool_handles_.at(0), cancel()); + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::LocalClose); +} + TEST_F(TcpProxyTest, UpstreamConnectFailure) { setup(1, accessLogConfig("%RESPONSE_FLAGS%")); @@ -873,6 +895,7 @@ TEST_F(TcpProxyTest, IdleTimeoutWithOutstandingDataFlushed) { TEST_F(TcpProxyTest, AccessLogUpstreamHost) { setup(1, accessLogConfig("%UPSTREAM_HOST% %UPSTREAM_CLUSTER%")); raiseEventUpstreamConnected(0); + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); filter_.reset(); EXPECT_EQ(access_log_data_, "127.0.0.1:80 fake_cluster"); } @@ -881,6 +904,7 @@ TEST_F(TcpProxyTest, AccessLogUpstreamHost) { TEST_F(TcpProxyTest, AccessLogUpstreamLocalAddress) { setup(1, accessLogConfig("%UPSTREAM_LOCAL_ADDRESS%")); raiseEventUpstreamConnected(0); + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); filter_.reset(); EXPECT_EQ(access_log_data_, "2.2.2.2:50000"); } @@ -893,6 +917,7 @@ TEST_F(TcpProxyTest, AccessLogDownstreamAddress) { filter_callbacks_.connection_.remote_address_ = Network::Utility::resolveUrl("tcp://1.1.1.1:40000"); setup(1, accessLogConfig("%DOWNSTREAM_REMOTE_ADDRESS_WITHOUT_PORT% %DOWNSTREAM_LOCAL_ADDRESS%")); + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); filter_.reset(); EXPECT_EQ(access_log_data_, "1.1.1.1 1.1.1.2:20000"); } @@ -1075,6 +1100,9 @@ TEST_F(TcpProxyRoutingTest, NonRoutableConnection) { EXPECT_EQ(total_cx + 1, config_->stats().downstream_cx_total_.value()); EXPECT_EQ(non_routable_cx + 1, config_->stats().downstream_cx_no_route_.value()); + + // Cleanup + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); } TEST_F(TcpProxyRoutingTest, RoutableConnection) { @@ -1087,7 +1115,8 @@ TEST_F(TcpProxyRoutingTest, RoutableConnection) { connection_.local_address_ = std::make_shared("1.2.3.4", 9999); // Expect filter to try to open a connection to specified cluster. - EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)); + EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillOnce(Return(nullptr)); filter_->onNewConnection(); diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index 1cc27974ea7b2..7a9d3c4b23ff3 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -245,6 +245,19 @@ envoy_extension_cc_test( ], ) +envoy_extension_cc_test( + name = "thrift_object_impl_test", + srcs = ["thrift_object_impl_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + ":mocks", + ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:thrift_object_lib", + "//test/test_common:printers_lib", + "//test/test_common:registry_lib", + ], +) + envoy_extension_cc_test( name = "integration_test", srcs = ["integration_test.cc"], diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc index 43f3352149723..338ce458123df 100644 --- a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -2,9 +2,12 @@ #include "common/buffer/buffer_impl.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" #include "extensions/filters/network/thrift_proxy/buffer_helper.h" #include "extensions/filters/network/thrift_proxy/config.h" #include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/header_transport_impl.h" #include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" @@ -17,8 +20,10 @@ #include "gtest/gtest.h" using testing::_; +using testing::AnyNumber; using testing::Invoke; using testing::NiceMock; +using testing::Ref; using testing::Return; using testing::ReturnRef; @@ -39,10 +44,23 @@ class TestConfigImpl : public ConfigImpl { void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override { callbacks.addDecoderFilter(decoder_filter_); } + TransportPtr createTransport() override { + if (transport_) { + return TransportPtr{transport_}; + } + return ConfigImpl::createTransport(); + } + ProtocolPtr createProtocol() override { + if (protocol_) { + return ProtocolPtr{protocol_}; + } + return ConfigImpl::createProtocol(); + } -private: ThriftFilters::DecoderFilterSharedPtr decoder_filter_; ThriftFilterStats& stats_; + MockTransport* transport_{}; + MockProtocol* protocol_{}; }; class ThriftConnectionManagerTest : public testing::Test { @@ -72,7 +90,14 @@ class ThriftConnectionManagerTest : public testing::Test { proto_config_.set_stat_prefix("test"); decoder_filter_.reset(new NiceMock()); + config_.reset(new TestConfigImpl(proto_config_, context_, decoder_filter_, stats_)); + if (custom_transport_) { + config_->transport_ = custom_transport_; + } + if (custom_protocol_) { + config_->protocol_ = custom_protocol_; + } filter_.reset(new ConnectionManager(*config_)); filter_->initializeReadFilterCallbacks(filter_callbacks_); @@ -267,6 +292,9 @@ class ThriftConnectionManagerTest : public testing::Test { Buffer::OwnedImpl write_buffer_; std::unique_ptr filter_; NiceMock filter_callbacks_; + + MockTransport* custom_transport_{}; + MockProtocol* custom_protocol_{}; }; TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftCall) { @@ -602,7 +630,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndResponse) { writeComplexFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -634,7 +664,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndExceptionResponse) { writeFramedBinaryTApplicationException(write_buffer_, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -667,7 +699,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndErrorResponse) { writeFramedBinaryIDLException(write_buffer_, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -700,7 +734,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndInvalidResponse) { // Call is not valid in a response writeFramedBinaryMessage(write_buffer_, MessageType::Call, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -739,7 +775,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndResponseProtocolError) { 0x08, 0xff, 0xff // illegal field id }); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_, write(_, false)); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); @@ -781,7 +819,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndTransportApplicationException) { 0x01, 0x02, 0x00, 0x00, // transforms: 1, 2; padding }); - callbacks->startUpstreamResponse(TransportType::Header, ProtocolType::Binary); + HeaderTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -817,15 +857,18 @@ TEST_F(ThriftConnectionManagerTest, PipelinedRequestAndResponse) { EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(2); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x01); - callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + callbacks.front()->startUpstreamResponse(transport, proto); EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); callbacks.pop_front(); EXPECT_EQ(1U, store_.counter("test.response").value()); EXPECT_EQ(1U, store_.counter("test.response_reply").value()); writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x02); - callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + callbacks.front()->startUpstreamResponse(transport, proto); EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); callbacks.pop_front(); EXPECT_EQ(2U, store_.counter("test.response").value()); @@ -857,6 +900,68 @@ TEST_F(ThriftConnectionManagerTest, ResetDownstreamConnection) { EXPECT_EQ(0U, store_.gauge("test.request_active").value()); } +TEST_F(ThriftConnectionManagerTest, DownstreamProtocolUpgrade) { + custom_transport_ = new NiceMock(); + custom_protocol_ = new NiceMock(); + initializeFilter(); + + EXPECT_CALL(*custom_transport_, decodeFrameStart(_, _)).WillOnce(Return(true)); + EXPECT_CALL(*custom_protocol_, readMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMessageType(MessageType::Call); + metadata.setProtocolUpgradeMessage(true); + return true; + })); + EXPECT_CALL(*custom_protocol_, supportsUpgrade()).Times(AnyNumber()).WillRepeatedly(Return(true)); + + MockDecoderEventHandler* upgrade_decoder = new NiceMock(); + EXPECT_CALL(*custom_protocol_, upgradeRequestDecoder()) + .WillOnce(Invoke([&]() -> DecoderEventHandlerSharedPtr { + return DecoderEventHandlerSharedPtr{upgrade_decoder}; + })); + EXPECT_CALL(*upgrade_decoder, messageBegin(_)).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_protocol_, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, structBegin(_)).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_protocol_, readFieldBegin(_, _, _, _)) + .WillOnce(Invoke( + [&](Buffer::Instance&, std::string&, FieldType& field_type, int16_t& field_id) -> bool { + field_type = FieldType::Stop; + field_id = 0; + return true; + })); + EXPECT_CALL(*custom_protocol_, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, structEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_protocol_, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, messageEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_transport_, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, transportEnd()).WillOnce(Return(FilterStatus::Continue)); + + MockDirectResponse* direct_response = new NiceMock(); + + EXPECT_CALL(*custom_protocol_, upgradeResponse(Ref(*upgrade_decoder))) + .WillOnce(Invoke([&](const DecoderEventHandler&) -> DirectResponsePtr { + return DirectResponsePtr{direct_response}; + })); + + EXPECT_CALL(*direct_response, encode(_, Ref(*custom_protocol_), _)) + .WillOnce(Invoke([&](MessageMetadata&, Protocol&, Buffer::Instance& buffer) -> void { + buffer.add("response"); + })); + EXPECT_CALL(*custom_transport_, encodeFrame(_, _, _)) + .WillOnce(Invoke( + [&](Buffer::Instance& buffer, const MessageMetadata&, Buffer::Instance& message) -> void { + EXPECT_EQ("response", message.toString()); + buffer.add("transport-encoded response"); + })); + EXPECT_CALL(filter_callbacks_.connection_, write(_, false)) + .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { + EXPECT_EQ("transport-encoded response", buffer.toString()); + })); + + Buffer::OwnedImpl buffer; + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index 6fd60d2256814..a8369e757b220 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -710,17 +710,17 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { } TEST(DecoderTest, OnData) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; @@ -732,7 +732,7 @@ TEST(DecoderTest, OnData) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); @@ -750,18 +750,18 @@ TEST(DecoderTest, OnData) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, structEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::Continue)); bool underflow = false; @@ -770,24 +770,24 @@ TEST(DecoderTest, OnData) { } TEST(DecoderTest, OnDataWithProtocolHint) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); metadata.setProtocol(ProtocolType::Binary); return true; })); - EXPECT_CALL(*proto, type()).WillOnce(Return(ProtocolType::Auto)); - EXPECT_CALL(*proto, setType(ProtocolType::Binary)); + EXPECT_CALL(proto, type()).WillOnce(Return(ProtocolType::Auto)); + EXPECT_CALL(proto, setType(ProtocolType::Binary)); EXPECT_CALL(handler, transportBegin(_)) .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { EXPECT_TRUE(metadata->hasFrameSize()); @@ -799,7 +799,7 @@ TEST(DecoderTest, OnDataWithProtocolHint) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); @@ -817,18 +817,18 @@ TEST(DecoderTest, OnDataWithProtocolHint) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, structEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::Continue)); bool underflow = false; @@ -837,23 +837,23 @@ TEST(DecoderTest, OnDataWithProtocolHint) { } TEST(DecoderTest, OnDataWithInconsistentProtocolHint) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); metadata.setProtocol(ProtocolType::Binary); return true; })); - EXPECT_CALL(*proto, type()).WillRepeatedly(Return(ProtocolType::Compact)); + EXPECT_CALL(proto, type()).WillRepeatedly(Return(ProtocolType::Compact)); bool underflow = false; EXPECT_THROW_WITH_MESSAGE(decoder.onData(buffer, underflow), EnvoyException, @@ -861,17 +861,17 @@ TEST(DecoderTest, OnDataWithInconsistentProtocolHint) { } TEST(DecoderTest, OnDataThrowsTransportAppException) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setAppException(AppExceptionType::InvalidTransform, "unknown xform"); return true; @@ -882,85 +882,85 @@ TEST(DecoderTest, OnDataThrowsTransportAppException) { } TEST(DecoderTest, OnDataResumes) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; NiceMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; buffer.add("x"); - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; })); - EXPECT_CALL(*proto, readMessageBegin(_, _)) + EXPECT_CALL(proto, readMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); metadata.setSequenceId(100); return true; })); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(false)); + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(false)); bool underflow = false; EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); // buffer.length() == 1 } TEST(DecoderTest, OnDataResumesTransportFrameStart) { - StrictMock* transport = new StrictMock(); - StrictMock* proto = new StrictMock(); + StrictMock transport; + StrictMock proto; NiceMock callbacks; NiceMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); - EXPECT_CALL(*transport, name()).Times(AnyNumber()); - EXPECT_CALL(*proto, name()).Times(AnyNumber()); + EXPECT_CALL(transport, name()).Times(AnyNumber()); + EXPECT_CALL(proto, name()).Times(AnyNumber()); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; bool underflow = false; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)).WillOnce(Return(false)); + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)).WillOnce(Return(false)); EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; })); - EXPECT_CALL(*proto, readMessageBegin(_, _)) + EXPECT_CALL(proto, readMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); metadata.setSequenceId(100); return true; })); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); underflow = false; EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); @@ -968,66 +968,65 @@ TEST(DecoderTest, OnDataResumesTransportFrameStart) { } TEST(DecoderTest, OnDataResumesTransportFrameEnd) { - StrictMock* transport = new StrictMock(); - StrictMock* proto = new StrictMock(); + StrictMock transport; + StrictMock proto; NiceMock callbacks; NiceMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); - EXPECT_CALL(*transport, name()).Times(AnyNumber()); - EXPECT_CALL(*proto, name()).Times(AnyNumber()); + EXPECT_CALL(transport, name()).Times(AnyNumber()); + EXPECT_CALL(proto, name()).Times(AnyNumber()); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; })); - EXPECT_CALL(*proto, readMessageBegin(_, _)) + EXPECT_CALL(proto, readMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); metadata.setSequenceId(100); return true; })); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(false)); + EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(false)); bool underflow = false; EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); // buffer.length() == 0 } TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { + StrictMock transport; + EXPECT_CALL(transport, name()).WillRepeatedly(ReturnRef(transport.name_)); - StrictMock* transport = new StrictMock(); - EXPECT_CALL(*transport, name()).WillRepeatedly(ReturnRef(transport->name_)); - - StrictMock* proto = new StrictMock(); - EXPECT_CALL(*proto, name()).WillRepeatedly(ReturnRef(proto->name_)); + StrictMock proto; + EXPECT_CALL(proto, name()).WillRepeatedly(ReturnRef(proto.name_)); NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; bool underflow = true; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; @@ -1042,7 +1041,7 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); @@ -1062,42 +1061,42 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())) .WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::I32), SetArgReferee<3>(1), Return(true))); EXPECT_CALL(handler, fieldBegin(absl::string_view(), FieldType::I32, 1)) .WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readInt32(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readInt32(_, _)).WillOnce(Return(true)); EXPECT_CALL(handler, int32Value(_)).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, fieldEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, structEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); diff --git a/test/extensions/filters/network/thrift_proxy/mocks.cc b/test/extensions/filters/network/thrift_proxy/mocks.cc index ad14d3c1149e1..ac0359dce7c17 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.cc +++ b/test/extensions/filters/network/thrift_proxy/mocks.cc @@ -27,6 +27,7 @@ MockProtocol::MockProtocol() { ON_CALL(*this, setType(_)).WillByDefault(Invoke([&](ProtocolType type) -> void { type_ = type; })); + ON_CALL(*this, supportsUpgrade()).WillByDefault(Return(false)); } MockProtocol::~MockProtocol() {} @@ -39,6 +40,9 @@ MockDecoderEventHandler::~MockDecoderEventHandler() {} MockDirectResponse::MockDirectResponse() {} MockDirectResponse::~MockDirectResponse() {} +MockThriftObject::MockThriftObject() {} +MockThriftObject::~MockThriftObject() {} + namespace ThriftFilters { MockDecoderFilter::MockDecoderFilter() { diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index b93b7717c8780..2067434b2db88 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -1,6 +1,7 @@ #pragma once #include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/conn_state.h" #include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/metadata.h" #include "extensions/filters/network/thrift_proxy/protocol.h" @@ -101,6 +102,12 @@ class MockProtocol : public Protocol { MOCK_METHOD2(writeDouble, void(Buffer::Instance& buffer, double value)); MOCK_METHOD2(writeString, void(Buffer::Instance& buffer, const std::string& value)); MOCK_METHOD2(writeBinary, void(Buffer::Instance& buffer, const std::string& value)); + MOCK_METHOD0(supportsUpgrade, bool()); + MOCK_METHOD0(upgradeRequestDecoder, DecoderEventHandlerSharedPtr()); + MOCK_METHOD1(upgradeResponse, DirectResponsePtr(const DecoderEventHandler&)); + MOCK_METHOD3(attemptUpgrade, + ThriftObjectPtr(Transport&, ThriftConnectionState&, Buffer::Instance&)); + MOCK_METHOD2(completeUpgrade, void(ThriftConnectionState&, ThriftObject&)); std::string name_{"mock"}; ProtocolType type_{ProtocolType::Auto}; @@ -154,6 +161,15 @@ class MockDirectResponse : public DirectResponse { MOCK_CONST_METHOD3(encode, void(MessageMetadata&, Protocol&, Buffer::Instance&)); }; +class MockThriftObject : public ThriftObject { +public: + MockThriftObject(); + ~MockThriftObject(); + + MOCK_CONST_METHOD0(fields, ThriftFieldPtrList&()); + MOCK_METHOD1(onData, bool(Buffer::Instance&)); +}; + namespace ThriftFilters { class MockDecoderFilter : public DecoderFilter { @@ -204,7 +220,7 @@ class MockDecoderFilterCallbacks : public DecoderFilterCallbacks { MOCK_CONST_METHOD0(downstreamTransportType, TransportType()); MOCK_CONST_METHOD0(downstreamProtocolType, ProtocolType()); MOCK_METHOD1(sendLocalReply, void(const DirectResponse&)); - MOCK_METHOD2(startUpstreamResponse, void(TransportType, ProtocolType)); + MOCK_METHOD2(startUpstreamResponse, void(Transport&, Protocol&)); MOCK_METHOD1(upstreamData, bool(Buffer::Instance&)); MOCK_METHOD0(resetDownstreamConnection, void()); diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc index 3b07f4d9186dd..31ce50e8327c7 100644 --- a/test/extensions/filters/network/thrift_proxy/router_test.cc +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -21,6 +21,7 @@ using testing::_; using testing::ContainsRegex; +using testing::InSequence; using testing::Invoke; using testing::NiceMock; using testing::Ref; @@ -71,8 +72,22 @@ class TestNamedProtocolConfigFactory : public NamedProtocolConfigFactory { class ThriftRouterTestBase { public: ThriftRouterTestBase() - : transport_factory_([&]() -> MockTransport* { return transport_; }), - protocol_factory_([&]() -> MockProtocol* { return protocol_; }), + : transport_factory_([&]() -> MockTransport* { + ASSERT(transport_ == nullptr); + transport_ = new NiceMock(); + if (mock_transport_cb_) { + mock_transport_cb_(transport_); + } + return transport_; + }), + protocol_factory_([&]() -> MockProtocol* { + ASSERT(protocol_ == nullptr); + protocol_ = new NiceMock(); + if (mock_protocol_cb_) { + mock_protocol_cb_(protocol_); + } + return protocol_; + }), transport_register_(transport_factory_), protocol_register_(protocol_factory_) {} void initializeRouter() { @@ -124,9 +139,6 @@ class ThriftRouterTestBase { upstream_callbacks_ = &cb; })); - protocol_ = new NiceMock(); - - ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary)); EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { EXPECT_EQ(metadata_->methodName(), metadata.methodName()); @@ -165,15 +177,16 @@ class ThriftRouterTestBase { EXPECT_CALL(callbacks_, downstreamTransportType()).WillOnce(Return(TransportType::Framed)); EXPECT_CALL(callbacks_, downstreamProtocolType()).WillOnce(Return(ProtocolType::Binary)); - protocol_ = new NiceMock(); - ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary)); - EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) - .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { - EXPECT_EQ(metadata_->methodName(), metadata.methodName()); - EXPECT_EQ(metadata_->messageType(), metadata.messageType()); - EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); - })); + mock_protocol_cb_ = [&](MockProtocol* protocol) -> void { + ON_CALL(*protocol, type()).WillByDefault(Return(ProtocolType::Binary)); + EXPECT_CALL(*protocol, writeMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ(metadata_->methodName(), metadata.methodName()); + EXPECT_EQ(metadata_->messageType(), metadata.messageType()); + EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); + })); + }; EXPECT_CALL(callbacks_, continueDecoding()).Times(0); EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, newConnection(_)) .WillOnce( @@ -240,8 +253,6 @@ class ThriftRouterTestBase { } void completeRequest() { - transport_ = new NiceMock(); - EXPECT_CALL(*protocol_, writeMessageEnd(_)); EXPECT_CALL(*transport_, encodeFrame(_, _, _)); EXPECT_CALL(upstream_connection_, write(_, false)); @@ -257,7 +268,7 @@ class ThriftRouterTestBase { void returnResponse() { Buffer::OwnedImpl buffer; - EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, startUpstreamResponse(_, _)); EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); upstream_callbacks_->onUpstreamData(buffer, false); @@ -277,6 +288,9 @@ class ThriftRouterTestBase { Registry::InjectFactory transport_register_; Registry::InjectFactory protocol_register_; + std::function mock_transport_cb_{}; + std::function mock_protocol_cb_{}; + NiceMock context_; NiceMock callbacks_; NiceMock* transport_{}; @@ -472,7 +486,7 @@ TEST_F(ThriftRouterTest, TruncatedResponse) { Buffer::OwnedImpl buffer; - EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, startUpstreamResponse(_, _)); EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, released(Ref(upstream_connection_))); EXPECT_CALL(callbacks_, resetDownstreamConnection()); @@ -531,7 +545,7 @@ TEST_F(ThriftRouterTest, UpstreamDataTriggersReset) { Buffer::OwnedImpl buffer; - EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, startUpstreamResponse(_, _)); EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))) .WillOnce(Invoke([&](Buffer::Instance&) -> bool { router_->resetUpstreamConnection(); @@ -590,6 +604,100 @@ TEST_F(ThriftRouterTest, UnexpectedRouterDestroy) { destroyRouter(); } +TEST_F(ThriftRouterTest, ProtocolUpgrade) { + initializeRouter(); + startRequest(MessageType::Call); + + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void { upstream_callbacks_ = &cb; })); + + Tcp::ConnectionPool::ConnectionStatePtr conn_state; + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, connectionState()) + .WillRepeatedly( + Invoke([&]() -> Tcp::ConnectionPool::ConnectionState* { return conn_state.get(); })); + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, setConnectionState_(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::ConnectionStatePtr& cs) -> void { conn_state.swap(cs); })); + + EXPECT_CALL(*protocol_, supportsUpgrade()).WillOnce(Return(true)); + + MockThriftObject* upgrade_response = new NiceMock(); + + EXPECT_CALL(*protocol_, attemptUpgrade(_, _, _)) + .WillOnce(Invoke( + [&](Transport&, ThriftConnectionState&, Buffer::Instance& buffer) -> ThriftObjectPtr { + buffer.add("upgrade request"); + return ThriftObjectPtr{upgrade_response}; + })); + EXPECT_CALL(upstream_connection_, write(_, false)) + .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { + EXPECT_EQ("upgrade request", buffer.toString()); + })); + + context_.cluster_manager_.tcp_conn_pool_.poolReady(upstream_connection_); + EXPECT_NE(nullptr, upstream_callbacks_); + + Buffer::OwnedImpl buffer; + EXPECT_CALL(*upgrade_response, onData(Ref(buffer))).WillOnce(Return(false)); + upstream_callbacks_->onUpstreamData(buffer, false); + + EXPECT_CALL(*upgrade_response, onData(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(*protocol_, completeUpgrade(_, Ref(*upgrade_response))); + EXPECT_CALL(callbacks_, continueDecoding()); + EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ(metadata_->methodName(), metadata.methodName()); + EXPECT_EQ(metadata_->messageType(), metadata.messageType()); + EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); + })); + upstream_callbacks_->onUpstreamData(buffer, false); + + // Then the actual request... + sendTrivialStruct(FieldType::String); + completeRequest(); + returnResponse(); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, ProtocolUpgradeSkippedOnExistingConnection) { + initializeRouter(); + startRequest(MessageType::Call); + + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void { upstream_callbacks_ = &cb; })); + + Tcp::ConnectionPool::ConnectionStatePtr conn_state = std::make_unique(); + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, connectionState()) + .WillRepeatedly( + Invoke([&]() -> Tcp::ConnectionPool::ConnectionState* { return conn_state.get(); })); + + EXPECT_CALL(*protocol_, supportsUpgrade()).WillOnce(Return(true)); + + // Protocol determines that connection state shows upgrade already occurred + EXPECT_CALL(*protocol_, attemptUpgrade(_, _, _)) + .WillOnce(Invoke([&](Transport&, ThriftConnectionState&, + Buffer::Instance&) -> ThriftObjectPtr { return nullptr; })); + + EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ(metadata_->methodName(), metadata.methodName()); + EXPECT_EQ(metadata_->messageType(), metadata.messageType()); + EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); + })); + EXPECT_CALL(callbacks_, continueDecoding()); + + context_.cluster_manager_.tcp_conn_pool_.poolReady(upstream_connection_); + EXPECT_NE(nullptr, upstream_callbacks_); + + // Then the actual request... + sendTrivialStruct(FieldType::String); + completeRequest(); + returnResponse(); + destroyRouter(); +} + TEST_P(ThriftRouterFieldTypeTest, OneWay) { FieldType field_type = GetParam(); diff --git a/test/extensions/filters/network/thrift_proxy/thrift_object_impl_test.cc b/test/extensions/filters/network/thrift_proxy/thrift_object_impl_test.cc new file mode 100644 index 0000000000000..3e8d12403dcd8 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/thrift_object_impl_test.cc @@ -0,0 +1,494 @@ +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/thrift_object_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/test_common/printers.h" +#include "test/test_common/utility.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::AnyNumber; +using testing::Expectation; +using testing::ExpectationSet; +using testing::InSequence; +using testing::NiceMock; +using testing::Ref; +using testing::Return; +using testing::ReturnRef; +using testing::Test; +using testing::TestWithParam; +using testing::Values; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class ThriftObjectImplTestBase { +public: + virtual ~ThriftObjectImplTestBase() {} + + Expectation expectValue(FieldType field_type) { + switch (field_type) { + case FieldType::Bool: + return EXPECT_CALL(protocol_, readBool(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, bool& value) -> bool { + value = true; + return true; + })); + case FieldType::Byte: + return EXPECT_CALL(protocol_, readByte(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, uint8_t& value) -> bool { + value = 1; + return true; + })); + case FieldType::Double: + return EXPECT_CALL(protocol_, readDouble(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, double& value) -> bool { + value = 2.0; + return true; + })); + case FieldType::I16: + return EXPECT_CALL(protocol_, readInt16(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, int16_t& value) -> bool { + value = 3; + return true; + })); + case FieldType::I32: + return EXPECT_CALL(protocol_, readInt32(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, int32_t& value) -> bool { + value = 4; + return true; + })); + case FieldType::I64: + return EXPECT_CALL(protocol_, readInt64(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, int64_t& value) -> bool { + value = 5; + return true; + })); + case FieldType::String: + return EXPECT_CALL(protocol_, readString(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, std::string& value) -> bool { + value = "six"; + return true; + })); + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } + + Expectation expectFieldBegin(FieldType field_type, int16_t field_id) { + return EXPECT_CALL(protocol_, readFieldBegin(Ref(buffer_), _, _, _)) + .WillOnce( + Invoke([=](Buffer::Instance&, std::string&, FieldType& type, int16_t& id) -> bool { + type = field_type; + id = field_id; + return true; + })); + } + + Expectation expectFieldEnd() { + return EXPECT_CALL(protocol_, readFieldEnd(Ref(buffer_))).WillOnce(Return(true)); + } + + ExpectationSet expectField(FieldType field_type, int16_t field_id) { + ExpectationSet s; + s += expectFieldBegin(field_type, field_id); + s += expectValue(field_type); + s += expectFieldEnd(); + return s; + } + + Expectation expectStopField() { return expectFieldBegin(FieldType::Stop, 0); } + + void checkValue(FieldType field_type, const ThriftValue& value) { + EXPECT_EQ(field_type, value.type()); + + switch (field_type) { + case FieldType::Bool: + EXPECT_EQ(true, value.getValueTyped()); + break; + case FieldType::Byte: + EXPECT_EQ(1, value.getValueTyped()); + break; + case FieldType::Double: + EXPECT_EQ(2.0, value.getValueTyped()); + break; + case FieldType::I16: + EXPECT_EQ(3, value.getValueTyped()); + break; + case FieldType::I32: + EXPECT_EQ(4, value.getValueTyped()); + break; + case FieldType::I64: + EXPECT_EQ(5, value.getValueTyped()); + break; + case FieldType::String: + EXPECT_EQ("six", value.getValueTyped()); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } + + void checkFieldValue(const ThriftField& field) { + const ThriftValue& value = field.getValue(); + checkValue(field.fieldType(), value); + } + + NiceMock transport_; + NiceMock protocol_; + Buffer::OwnedImpl buffer_; +}; + +class ThriftObjectImplTest : public ThriftObjectImplTestBase, public Test {}; + +// Test parsing an empty struct (just a stop field). +TEST_F(ThriftObjectImplTest, ParseEmptyStruct) { + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_TRUE(thrift_obj.fields().empty()); +} + +class ThriftObjectImplValueTest : public ThriftObjectImplTestBase, + public TestWithParam {}; + +INSTANTIATE_TEST_CASE_P(PrimitiveFieldTypes, ThriftObjectImplValueTest, + Values(FieldType::Bool, FieldType::Byte, FieldType::Double, FieldType::I16, + FieldType::I32, FieldType::I64, FieldType::String), + fieldTypeParamToString); + +// Test parsing a struct with a single field with a simple value. +TEST_P(ThriftObjectImplValueTest, ParseSingleValueStruct) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(field_type, 1); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + EXPECT_EQ(field_type, thrift_obj.fields().front()->fieldType()); + EXPECT_EQ(1, thrift_obj.fields().front()->fieldId()); + checkFieldValue(*thrift_obj.fields().front()); +} + +// Test parsing nested structs (struct -> struct -> simple field). +TEST_P(ThriftObjectImplValueTest, ParseNestedSingleValueStruct) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Struct, 1); + + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(field_type, 2); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(FieldType::Struct, field.fieldType()); + + const ThriftStructValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(1, nested.fields().size()); + EXPECT_EQ(field_type, nested.fields().front()->fieldType()); + EXPECT_EQ(2, nested.fields().front()->fieldId()); + checkFieldValue(*nested.fields().front()); +} + +// Test parsing a struct with a single list field (struct -> list). +TEST_P(ThriftObjectImplValueTest, ParseNestedListValue) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::List, 1); + + EXPECT_CALL(protocol_, readListBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& type, uint32_t& size) -> bool { + type = field_type; + size = 2; + return true; + })); + expectValue(field_type); + expectValue(field_type); + EXPECT_CALL(protocol_, readListEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(1, field.fieldId()); + EXPECT_EQ(FieldType::List, field.fieldType()); + + const ThriftListValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(field_type, nested.elementType()); + EXPECT_EQ(2, nested.elements().size()); + for (auto& value : nested.elements()) { + checkValue(field_type, *value); + } +} + +// Test parsing a struct with a single set field (struct -> set). +TEST_P(ThriftObjectImplValueTest, ParseNestedSetValue) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Set, 1); + + EXPECT_CALL(protocol_, readSetBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& type, uint32_t& size) -> bool { + type = field_type; + size = 2; + return true; + })); + expectValue(field_type); + expectValue(field_type); + EXPECT_CALL(protocol_, readSetEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(1, field.fieldId()); + EXPECT_EQ(FieldType::Set, field.fieldType()); + + const ThriftSetValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(field_type, nested.elementType()); + EXPECT_EQ(2, nested.elements().size()); + for (auto& value : nested.elements()) { + checkValue(field_type, *value); + } +} + +// Test parsing a struct with a single map field (struct -> map). +TEST_P(ThriftObjectImplValueTest, ParseNestedMapValue) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Map, 1); + + EXPECT_CALL(protocol_, readMapBegin(Ref(buffer_), _, _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& key_type, FieldType& value_type, + uint32_t& size) -> bool { + key_type = field_type; + value_type = FieldType::String; + size = 2; + return true; + })); + expectValue(field_type); + expectValue(FieldType::String); + expectValue(field_type); + expectValue(FieldType::String); + EXPECT_CALL(protocol_, readMapEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(1, field.fieldId()); + EXPECT_EQ(FieldType::Map, field.fieldType()); + + const ThriftMapValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(field_type, nested.keyType()); + EXPECT_EQ(FieldType::String, nested.valueType()); + EXPECT_EQ(2, nested.elements().size()); + for (auto& value : nested.elements()) { + checkValue(field_type, *value.first); + checkValue(FieldType::String, *value.second); + } +} + +// Test a struct with a map -> list -> set -> map -> list -> set -> struct. +TEST_F(ThriftObjectImplTest, DeeplyNestedStruct) { + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Map, 1); + + EXPECT_CALL(protocol_, readMapBegin(Ref(buffer_), _, _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& key_type, FieldType& value_type, + uint32_t& size) -> bool { + key_type = FieldType::I32; + value_type = FieldType::List; + size = 1; + return true; + })); + expectValue(FieldType::I32); + EXPECT_CALL(protocol_, readListBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Set; + size = 1; + return true; + })); + EXPECT_CALL(protocol_, readSetBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Map; + size = 1; + return true; + })); + + EXPECT_CALL(protocol_, readMapBegin(Ref(buffer_), _, _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& key_type, FieldType& value_type, + uint32_t& size) -> bool { + key_type = FieldType::I32; + value_type = FieldType::List; + size = 1; + return true; + })); + expectValue(FieldType::I32); + EXPECT_CALL(protocol_, readListBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Set; + size = 1; + return true; + })); + EXPECT_CALL(protocol_, readSetBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Struct; + size = 1; + return true; + })); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(FieldType::I64, 100); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readSetEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readListEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMapEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readSetEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readListEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMapEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + + EXPECT_EQ(FieldType::Map, thrift_obj.fields().front()->fieldType()); + const ThriftMapValue& map_value = + thrift_obj.fields().front()->getValue().getValueTyped(); + EXPECT_EQ(1, map_value.elements().size()); + + const ThriftListValue& list_value = + map_value.elements().front().second->getValueTyped(); + EXPECT_EQ(1, list_value.elements().size()); + + const ThriftSetValue& set_value = list_value.elements().front()->getValueTyped(); + EXPECT_EQ(1, set_value.elements().size()); + + const ThriftMapValue& map_value2 = set_value.elements().front()->getValueTyped(); + EXPECT_EQ(1, map_value2.elements().size()); + + const ThriftListValue& list_value2 = + map_value2.elements().front().second->getValueTyped(); + EXPECT_EQ(1, list_value2.elements().size()); + + const ThriftSetValue& set_value2 = + list_value2.elements().front()->getValueTyped(); + EXPECT_EQ(1, set_value2.elements().size()); + + const ThriftStructValue& struct_value = + set_value2.elements().front()->getValueTyped(); + EXPECT_EQ(1, struct_value.fields().size()); + + EXPECT_EQ(5, struct_value.fields().front()->getValue().getValueTyped()); +} + +// Tests when caller requests wrong value type. +TEST_F(ThriftObjectImplTest, WrongValueType) { + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(FieldType::String, 1); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + + const ThriftValue& value = thrift_obj.fields().front()->getValue(); + EXPECT_THROW_WITH_MESSAGE(value.getValueTyped(), EnvoyException, + fmt::format("expected field type {}, got {}", + static_cast(FieldType::I32), + static_cast(FieldType::String))); +} + +} // Namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/stats_sinks/metrics_service/grpc_metrics_service_impl_test.cc b/test/extensions/stats_sinks/metrics_service/grpc_metrics_service_impl_test.cc index f2996770d3a26..627eabf08dab0 100644 --- a/test/extensions/stats_sinks/metrics_service/grpc_metrics_service_impl_test.cc +++ b/test/extensions/stats_sinks/metrics_service/grpc_metrics_service_impl_test.cc @@ -1,5 +1,6 @@ #include "extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.h" +#include "test/mocks/common.h" #include "test/mocks/grpc/mocks.h" #include "test/mocks/local_info/mocks.h" #include "test/mocks/stats/mocks.h" @@ -98,9 +99,10 @@ class MetricsServiceSinkTest : public testing::Test {}; TEST(MetricsServiceSinkTest, CheckSendCall) { NiceMock source; + NiceMock mock_time; std::shared_ptr streamer_{new MockGrpcMetricsStreamer()}; - MetricsServiceSink sink(streamer_); + MetricsServiceSink sink(streamer_, mock_time); auto counter = std::make_shared>(); counter->name_ = "test_counter"; @@ -125,9 +127,10 @@ TEST(MetricsServiceSinkTest, CheckSendCall) { TEST(MetricsServiceSinkTest, CheckStatsCount) { NiceMock source; + NiceMock mock_time; std::shared_ptr streamer_{new TestGrpcMetricsStreamer()}; - MetricsServiceSink sink(streamer_); + MetricsServiceSink sink(streamer_, mock_time); auto counter = std::make_shared>(); counter->name_ = "test_counter"; diff --git a/test/integration/echo_integration_test.cc b/test/integration/echo_integration_test.cc index 1eef8fb6a0ebb..d38c2eb131569 100644 --- a/test/integration/echo_integration_test.cc +++ b/test/integration/echo_integration_test.cc @@ -130,7 +130,12 @@ TEST_P(EchoIntegrationTest, AddRemoveListener) { RawConnectionDriver connection2( new_listener_port, buffer, [&](Network::ClientConnection&, const Buffer::Instance&) -> void { FAIL(); }, version_); - connection2.run(Event::Dispatcher::RunType::NonBlock); + while (connection2.connecting()) { + // Don't busy loop, but OS X often needs a moment to decide this connection isn't happening. + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + connection2.run(Event::Dispatcher::RunType::NonBlock); + } if (connection2.connection().state() == Network::Connection::State::Closed) { connect_fail = true; break; diff --git a/test/integration/fake_upstream.cc b/test/integration/fake_upstream.cc index 5a370c556535f..571184ffa75bf 100644 --- a/test/integration/fake_upstream.cc +++ b/test/integration/fake_upstream.cc @@ -248,7 +248,7 @@ AssertionResult FakeConnectionBase::waitForDisconnect(bool ignore_spurious_event Thread::LockGuard lock(lock_); while (shared_connection_.connected()) { if (std::chrono::steady_clock::now() >= end_time) { - return AssertionResult("Timed out waiting for disconnect."); + return AssertionFailure() << "Timed out waiting for disconnect."; } Thread::CondVar::WaitStatus status = connection_event_.waitFor(lock_, 5ms); // The default behavior of waitForDisconnect is to assume the test cleanly @@ -300,7 +300,7 @@ AssertionResult FakeHttpConnection::waitForNewStream(Event::Dispatcher& client_d Thread::LockGuard lock(lock_); while (new_streams_.empty()) { if (std::chrono::steady_clock::now() >= end_time) { - return AssertionResult("Timed out waiting for new stream."); + return AssertionFailure() << "Timed out waiting for new stream."; } Thread::CondVar::WaitStatus status = connection_event_.waitFor(lock_, 5ms); // As with waitForDisconnect, by default, waitForNewStream returns after the next event. @@ -536,7 +536,7 @@ AssertionResult FakeRawConnection::write(const std::string& data, bool end_strea Network::FilterStatus FakeRawConnection::ReadFilter::onData(Buffer::Instance& data, bool end_stream) { Thread::LockGuard lock(parent_.lock_); - ENVOY_LOG(debug, "got {} bytes", data.length()); + ENVOY_LOG(debug, "got {} bytes, end_stream {}", data.length(), end_stream); parent_.data_.append(data.toString()); parent_.half_closed_ = end_stream; data.drain(data.length()); diff --git a/test/integration/http_integration.cc b/test/integration/http_integration.cc index 293d761d93947..36bf5c9ad2ccb 100644 --- a/test/integration/http_integration.cc +++ b/test/integration/http_integration.cc @@ -584,6 +584,13 @@ void HttpIntegrationTest::testRouterDownstreamDisconnectBeforeRequestComplete( void HttpIntegrationTest::testRouterDownstreamDisconnectBeforeResponseComplete( ConnectionCreationFunction* create_connection) { +#ifdef __APPLE__ + // Skip this test on OS X: we can't detect the early close on OS X, and we + // won't clean up the upstream connection until it times out. See #4294. + if (downstream_protocol_ == Http::CodecClient::Type::HTTP1) { + return; + } +#endif initialize(); codec_client_ = makeHttpConnection( create_connection ? ((*create_connection)()) : makeClientConnection((lookupPort("http")))); diff --git a/test/integration/sds_static_integration_test.cc b/test/integration/sds_static_integration_test.cc index cc259d31ffe7f..f7551a751cd2e 100644 --- a/test/integration/sds_static_integration_test.cc +++ b/test/integration/sds_static_integration_test.cc @@ -47,16 +47,20 @@ class SdsStaticDownstreamIntegrationTest ->mutable_common_tls_context(); common_tls_context->add_alpn_protocols("http/1.1"); - auto* validation_context = common_tls_context->mutable_validation_context(); + common_tls_context->mutable_validation_context_sds_secret_config()->set_name( + "validation_context"); + common_tls_context->add_tls_certificate_sds_secret_configs()->set_name("server_cert"); + + auto* secret = bootstrap.mutable_static_resources()->add_secrets(); + secret->set_name("validation_context"); + auto* validation_context = secret->mutable_validation_context(); validation_context->mutable_trusted_ca()->set_filename( TestEnvironment::runfilesPath("test/config/integration/certs/cacert.pem")); validation_context->add_verify_certificate_hash( "E0:F3:C8:CE:5E:2E:A3:05:F0:70:1F:F5:12:E3:6E:2E:" "97:92:82:84:A2:28:BC:F7:73:32:D3:39:30:A1:B6:FD"); - common_tls_context->add_tls_certificate_sds_secret_configs()->set_name("server_cert"); - - auto* secret = bootstrap.mutable_static_resources()->add_secrets(); + secret = bootstrap.mutable_static_resources()->add_secrets(); secret->set_name("server_cert"); auto* tls_certificate = secret->mutable_tls_certificate(); tls_certificate->mutable_certificate_chain()->set_filename( diff --git a/test/integration/ssl_integration_test.cc b/test/integration/ssl_integration_test.cc index f72f0e7c518c0..f79f6f781c769 100644 --- a/test/integration/ssl_integration_test.cc +++ b/test/integration/ssl_integration_test.cc @@ -154,6 +154,13 @@ TEST_P(SslIntegrationTest, RouterDownstreamDisconnectBeforeRequestComplete) { } TEST_P(SslIntegrationTest, RouterDownstreamDisconnectBeforeResponseComplete) { +#ifdef __APPLE__ + // Skip this test on OS X: we can't detect the early close on OS X, and we + // won't clean up the upstream connection until it times out. See #4294. + if (downstream_protocol_ == Http::CodecClient::Type::HTTP1) { + return; + } +#endif ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { return makeSslClientConnection(false, false); }; diff --git a/test/integration/utility.cc b/test/integration/utility.cc index ff815315c6531..9e136407254fd 100644 --- a/test/integration/utility.cc +++ b/test/integration/utility.cc @@ -111,10 +111,12 @@ RawConnectionDriver::RawConnectionDriver(uint32_t port, Buffer::Instance& initia Network::Address::IpVersion version) { api_.reset(new Api::Impl(std::chrono::milliseconds(10000))); dispatcher_ = api_->allocateDispatcher(IntegrationUtil::evil_singleton_test_time_.timeSource()); + callbacks_ = std::make_unique(); client_ = dispatcher_->createClientConnection( Network::Utility::resolveUrl( fmt::format("tcp://{}:{}", Network::Test::getLoopbackAddressUrlString(version), port)), Network::Address::InstanceConstSharedPtr(), Network::Test::createRawBufferSocket(), nullptr); + client_->addConnectionCallbacks(*callbacks_); client_->addReadFilter(Network::ReadFilterSharedPtr{new ForwardingFilter(*this, data_callback)}); client_->write(initial_data, false); client_->connect(); diff --git a/test/integration/utility.h b/test/integration/utility.h index a1bd07f5ece94..04e5d650eebb3 100644 --- a/test/integration/utility.h +++ b/test/integration/utility.h @@ -62,6 +62,7 @@ class RawConnectionDriver { Network::Address::IpVersion version); ~RawConnectionDriver(); const Network::Connection& connection() { return *client_; } + bool connecting() { return callbacks_->connecting_; } void run(Event::Dispatcher::RunType run_type = Event::Dispatcher::RunType::Block); void close(); @@ -81,8 +82,17 @@ class RawConnectionDriver { ReadCallback data_callback_; }; + struct ConnectionCallbacks : public Network::ConnectionCallbacks { + void onEvent(Network::ConnectionEvent) override { connecting_ = false; } + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + + bool connecting_{true}; + }; + Api::ApiPtr api_; Event::DispatcherPtr dispatcher_; + std::unique_ptr callbacks_; Network::ClientConnectionPtr client_; }; diff --git a/test/mocks/secret/mocks.h b/test/mocks/secret/mocks.h index 2637d6d850325..428f1ec28faad 100644 --- a/test/mocks/secret/mocks.h +++ b/test/mocks/secret/mocks.h @@ -32,6 +32,11 @@ class MockSecretManager : public SecretManager { TlsCertificateConfigProviderSharedPtr( const envoy::api::v2::core::ConfigSource&, const std::string&, Server::Configuration::TransportSocketFactoryContext&)); + MOCK_METHOD3(findOrCreateCertificateValidationContextProvider, + CertificateValidationContextConfigProviderSharedPtr( + const envoy::api::v2::core::ConfigSource& config_source, + const std::string& config_name, + Server::Configuration::TransportSocketFactoryContext& secret_provider_context)); }; class MockSecretCallbacks : public SecretCallbacks {