From 9cb2e20a5533bd827ae74778774668d8f8cf7d3c Mon Sep 17 00:00:00 2001 From: JimmyCYJ Date: Tue, 4 Sep 2018 21:55:55 -0700 Subject: [PATCH 01/15] Refactor SdsApi to support dynamic certificate validation context. Signed-off-by: JimmyCYJ --- include/envoy/secret/secret_manager.h | 16 +++++ source/common/secret/BUILD | 1 + source/common/secret/sds_api.cc | 42 +++++++---- source/common/secret/sds_api.h | 77 ++++++++++++++++++--- source/common/secret/secret_manager_impl.cc | 51 ++++++++++++-- source/common/secret/secret_manager_impl.h | 19 +++-- 6 files changed, 172 insertions(+), 34 deletions(-) diff --git a/include/envoy/secret/secret_manager.h b/include/envoy/secret/secret_manager.h index ca02d639643ab..93205243f046e 100644 --- a/include/envoy/secret/secret_manager.h +++ b/include/envoy/secret/secret_manager.h @@ -75,6 +75,22 @@ class SecretManager { virtual TlsCertificateConfigProviderSharedPtr findOrCreateTlsCertificateProvider( const envoy::api::v2::core::ConfigSource& config_source, const std::string& config_name, Server::Configuration::TransportSocketFactoryContext& secret_provider_context) PURE; + + /** + * Finds and returns a dynamic secret provider associated to SDS config. Create + * a new one if such provider does not exist. + * + * @param config_source a protobuf message object containing a SDS config source. + * @param config_name a name that uniquely refers to the SDS config source. + * @param secret_provider_context context that provides components for creating and initializing + * secret provider. + * @return CertificateValidationContextConfigProviderSharedPtr the dynamic certificate validation + * context secret provider. + */ + virtual CertificateValidationContextConfigProviderSharedPtr + findOrCreateCertificateValidationContextProvider( + const envoy::api::v2::core::ConfigSource& config_source, const std::string& config_name, + Server::Configuration::TransportSocketFactoryContext& secret_provider_context) PURE; }; } // namespace Secret diff --git a/source/common/secret/BUILD b/source/common/secret/BUILD index f45602e4c10c9..d0106e0eaeba5 100644 --- a/source/common/secret/BUILD +++ b/source/common/secret/BUILD @@ -52,6 +52,7 @@ envoy_cc_library( "//source/common/config:resources_lib", "//source/common/config:subscription_factory_lib", "//source/common/protobuf:utility_lib", + "//source/common/ssl:certificate_validation_context_config_impl_lib", "//source/common/ssl:tls_certificate_config_impl_lib", ], ) diff --git a/source/common/secret/sds_api.cc b/source/common/secret/sds_api.cc index f6d3b9f9db984..7b475d151ae33 100644 --- a/source/common/secret/sds_api.cc +++ b/source/common/secret/sds_api.cc @@ -7,6 +7,7 @@ #include "common/config/resources.h" #include "common/config/subscription_factory.h" #include "common/protobuf/utility.h" +#include "common/ssl/certificate_validation_context_config_impl.h" #include "common/ssl/tls_certificate_config_impl.h" namespace Envoy { @@ -17,9 +18,9 @@ SdsApi::SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispat Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, const envoy::api::v2::core::ConfigSource& sds_config, std::string sds_config_name, std::function destructor_cb) - : local_info_(local_info), dispatcher_(dispatcher), random_(random), stats_(stats), - cluster_manager_(cluster_manager), sds_config_(sds_config), sds_config_name_(sds_config_name), - secret_hash_(0), clean_up_(destructor_cb) { + : secret_hash_(0), local_info_(local_info), dispatcher_(dispatcher), random_(random), + stats_(stats), cluster_manager_(cluster_manager), sds_config_(sds_config), + sds_config_name_(sds_config_name), clean_up_(destructor_cb) { // TODO(JimmyCYJ): Implement chained_init_manager, so that multiple init_manager // can be chained together to behave as one init_manager. In that way, we let // two listeners which share same SdsApi to register at separate init managers, and @@ -59,15 +60,7 @@ void SdsApi::onConfigUpdate(const ResourceVector& resources, const std::string&) fmt::format("Unexpected SDS secret (expecting {}): {}", sds_config_name_, secret.name())); } - const uint64_t new_hash = MessageUtil::hash(secret); - if (new_hash != secret_hash_ && - secret.type_case() == envoy::api::v2::auth::Secret::TypeCase::kTlsCertificate) { - secret_hash_ = new_hash; - tls_certificate_secrets_ = - std::make_unique(secret.tls_certificate()); - - update_callback_manager_.runCallbacks(); - } + updateConfigHelper(secret); runInitializeCallbackIfAny(); } @@ -84,5 +77,30 @@ void SdsApi::runInitializeCallbackIfAny() { } } +void TlsCertificateSdsApi::updateConfigHelper(const envoy::api::v2::auth::Secret& secret) { + const uint64_t new_hash = MessageUtil::hash(secret); + if (new_hash != secret_hash_ && + secret.type_case() == envoy::api::v2::auth::Secret::TypeCase::kTlsCertificate) { + secret_hash_ = new_hash; + tls_certificate_secrets_ = + std::make_unique(secret.tls_certificate()); + + update_callback_manager_.runCallbacks(); + } +} + +void CertificateValidationContextSdsApi::updateConfigHelper( + const envoy::api::v2::auth::Secret& secret) { + const uint64_t new_hash = MessageUtil::hash(secret); + if (new_hash != secret_hash_ && + secret.type_case() == envoy::api::v2::auth::Secret::TypeCase::kValidationContext) { + secret_hash_ = new_hash; + certificate_validation_context_secrets_ = + std::make_unique(secret.validation_context()); + + update_callback_manager_.runCallbacks(); + } +} + } // namespace Secret } // namespace Envoy diff --git a/source/common/secret/sds_api.h b/source/common/secret/sds_api.h index f9211e3832b28..224994b683852 100644 --- a/source/common/secret/sds_api.h +++ b/source/common/secret/sds_api.h @@ -24,7 +24,6 @@ namespace Secret { * SDS API implementation that fetches secrets from SDS server via Subscription. */ class SdsApi : public Init::Target, - public TlsCertificateConfigProvider, public Config::SubscriptionCallbacks { public: SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, @@ -43,14 +42,10 @@ class SdsApi : public Init::Target, return MessageUtil::anyConvert(resource).name(); } - // SecretProvider - const Ssl::TlsCertificateConfig* secret() const override { - return tls_certificate_secrets_.get(); - } - - Common::CallbackHandle* addUpdateCallback(std::function callback) override { - return update_callback_manager_.add(callback); - } +protected: + // Updates local storage of dynamic secrets and invokes callbacks. + virtual void updateConfigHelper(const envoy::api::v2::auth::Secret&) {} + uint64_t secret_hash_; private: void runInitializeCallbackIfAny(); @@ -66,13 +61,73 @@ class SdsApi : public Init::Target, std::function initialize_callback_; const std::string sds_config_name_; - uint64_t secret_hash_; Cleanup clean_up_; +}; + +/** + * TlsCertificateSdsApi implementation maintains and updates dynamic TLS certificate secrets. + */ +class TlsCertificateSdsApi : public SdsApi, public TlsCertificateConfigProvider { +public: + TlsCertificateSdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, + Runtime::RandomGenerator& random, Stats::Store& stats, + Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, + const envoy::api::v2::core::ConfigSource& sds_config, + std::string sds_config_name, std::function destructor_cb) + : SdsApi(local_info, dispatcher, random, stats, cluster_manager, init_manager, sds_config, + sds_config_name, destructor_cb) {} + + // SecretProvider + const Ssl::TlsCertificateConfig* secret() const override { + return tls_certificate_secrets_.get(); + } + + Common::CallbackHandle* addUpdateCallback(std::function callback) override { + return update_callback_manager_.add(callback); + } + +private: + // SdsApi + void updateConfigHelper(const envoy::api::v2::auth::Secret& secret) override; + Ssl::TlsCertificateConfigPtr tls_certificate_secrets_; Common::CallbackManager<> update_callback_manager_; }; -typedef std::unique_ptr SdsApiPtr; +/** + * CertificateValidationContextSdsApi implementation maintains and updates dynamic certificate + * validation context secrets. + */ +class CertificateValidationContextSdsApi : public SdsApi, + public CertificateValidationContextConfigProvider { +public: + CertificateValidationContextSdsApi(const LocalInfo::LocalInfo& local_info, + Event::Dispatcher& dispatcher, + Runtime::RandomGenerator& random, Stats::Store& stats, + Upstream::ClusterManager& cluster_manager, + Init::Manager& init_manager, + const envoy::api::v2::core::ConfigSource& sds_config, + std::string sds_config_name, + std::function destructor_cb) + : SdsApi(local_info, dispatcher, random, stats, cluster_manager, init_manager, sds_config, + sds_config_name, destructor_cb) {} + + // SecretProvider + const Ssl::CertificateValidationContextConfig* secret() const override { + return certificate_validation_context_secrets_.get(); + } + + Common::CallbackHandle* addUpdateCallback(std::function callback) override { + return update_callback_manager_.add(callback); + } + +private: + // SdsApi + void updateConfigHelper(const envoy::api::v2::auth::Secret& secret) override; + + Ssl::CertificateValidationContextConfigPtr certificate_validation_context_secrets_; + Common::CallbackManager<> update_callback_manager_; +}; } // namespace Secret } // namespace Envoy diff --git a/source/common/secret/secret_manager_impl.cc b/source/common/secret/secret_manager_impl.cc index a310d59777ac9..09e701982b4ea 100644 --- a/source/common/secret/secret_manager_impl.cc +++ b/source/common/secret/secret_manager_impl.cc @@ -64,10 +64,10 @@ SecretManagerImpl::createInlineCertificateValidationContextProvider( certificate_validation_context); } -void SecretManagerImpl::removeDynamicSecretProvider(const std::string& map_key) { - ENVOY_LOG(debug, "Unregister secret provider. hash key: {}", map_key); +void SecretManagerImpl::removeDynamicTlsCertificateProvider(const std::string& map_key) { + ENVOY_LOG(debug, "Unregister tls certificate provider. hash key: {}", map_key); - auto num_deleted = dynamic_secret_providers_.erase(map_key); + auto num_deleted = dynamic_tls_certificate_providers_.erase(map_key); ASSERT(num_deleted == 1, ""); } @@ -76,22 +76,59 @@ TlsCertificateConfigProviderSharedPtr SecretManagerImpl::findOrCreateTlsCertific Server::Configuration::TransportSocketFactoryContext& secret_provider_context) { const std::string map_key = sds_config_source.SerializeAsString() + config_name; - TlsCertificateConfigProviderSharedPtr secret_provider = dynamic_secret_providers_[map_key].lock(); + TlsCertificateConfigProviderSharedPtr secret_provider = + dynamic_tls_certificate_providers_[map_key].lock(); if (!secret_provider) { ASSERT(secret_provider_context.initManager() != nullptr); // SdsApi is owned by ListenerImpl and ClusterInfo which are destroyed before // SecretManagerImpl. It is safe to invoke this callback at the destructor of SdsApi. std::function unregister_secret_provider = [map_key, this]() { - removeDynamicSecretProvider(map_key); + removeDynamicTlsCertificateProvider(map_key); }; - secret_provider = std::make_shared( + secret_provider = std::make_shared( secret_provider_context.localInfo(), secret_provider_context.dispatcher(), secret_provider_context.random(), secret_provider_context.stats(), secret_provider_context.clusterManager(), *secret_provider_context.initManager(), sds_config_source, config_name, unregister_secret_provider); - dynamic_secret_providers_[map_key] = secret_provider; + dynamic_tls_certificate_providers_[map_key] = secret_provider; + } + + return secret_provider; +} + +void SecretManagerImpl::removeDynamicCertificateValidationContextProvider( + const std::string& map_key) { + ENVOY_LOG(debug, "Unregister certificate validation context provider. hash key: {}", map_key); + + auto num_deleted = dynamic_certificate_validation_context_providers_.erase(map_key); + ASSERT(num_deleted == 1, ""); +} + +CertificateValidationContextConfigProviderSharedPtr +SecretManagerImpl::findOrCreateCertificateValidationContextProvider( + const envoy::api::v2::core::ConfigSource& sds_config_source, const std::string& config_name, + Server::Configuration::TransportSocketFactoryContext& secret_provider_context) { + const std::string map_key = sds_config_source.SerializeAsString() + config_name; + + CertificateValidationContextConfigProviderSharedPtr secret_provider = + dynamic_certificate_validation_context_providers_[map_key].lock(); + if (!secret_provider) { + ASSERT(secret_provider_context.initManager() != nullptr); + + // SdsApi is owned by ListenerImpl and ClusterInfo which are destroyed before + // SecretManagerImpl. It is safe to invoke this callback at the destructor of SdsApi. + std::function unregister_secret_provider = [map_key, this]() { + removeDynamicCertificateValidationContextProvider(map_key); + }; + + secret_provider = std::make_shared( + secret_provider_context.localInfo(), secret_provider_context.dispatcher(), + secret_provider_context.random(), secret_provider_context.stats(), + secret_provider_context.clusterManager(), *secret_provider_context.initManager(), + sds_config_source, config_name, unregister_secret_provider); + dynamic_certificate_validation_context_providers_[map_key] = secret_provider; } return secret_provider; diff --git a/source/common/secret/secret_manager_impl.h b/source/common/secret/secret_manager_impl.h index a6017ff8719c3..97c18297fdc96 100644 --- a/source/common/secret/secret_manager_impl.h +++ b/source/common/secret/secret_manager_impl.h @@ -35,9 +35,16 @@ class SecretManagerImpl : public SecretManager, Logger::Loggable @@ -47,9 +54,13 @@ class SecretManagerImpl : public SecretManager, Logger::Loggable static_certificate_validation_context_providers_; - // map hash code of SDS config source and SdsApi object. + // map hash code of SDS config source and TlsCertificateSdsApi object. std::unordered_map> - dynamic_secret_providers_; + dynamic_tls_certificate_providers_; + + // map hash code of SDS config source and CertificateValidationContextSdsApi object. + std::unordered_map> + dynamic_certificate_validation_context_providers_; }; } // namespace Secret From aa06142fffdac9dcdba97b13c67973b0b5dbf156 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Wed, 5 Sep 2018 14:43:35 -0400 Subject: [PATCH 02/15] test: Stop fake_upstream methods from accidentally succeeding (#4232) Description: FakeConnectionBase::waitForDisconnect and FakeHttpConnection::waitForNewStream were returning assertion successes when they timed out, because an AssertionResult constructed with a (non-empty) string counts as a success. Fix that. Risk Level: Low (test only) Testing: bazel test //test/... Docs Changes: n/a Release Notes: n/a Signed-off-by: Michael Behr --- test/integration/fake_upstream.cc | 4 ++-- test/integration/http_integration.cc | 7 +++++++ test/integration/ssl_integration_test.cc | 7 +++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/test/integration/fake_upstream.cc b/test/integration/fake_upstream.cc index 5a370c556535f..42ec7c2c4e33b 100644 --- a/test/integration/fake_upstream.cc +++ b/test/integration/fake_upstream.cc @@ -248,7 +248,7 @@ AssertionResult FakeConnectionBase::waitForDisconnect(bool ignore_spurious_event Thread::LockGuard lock(lock_); while (shared_connection_.connected()) { if (std::chrono::steady_clock::now() >= end_time) { - return AssertionResult("Timed out waiting for disconnect."); + return AssertionFailure() << "Timed out waiting for disconnect."; } Thread::CondVar::WaitStatus status = connection_event_.waitFor(lock_, 5ms); // The default behavior of waitForDisconnect is to assume the test cleanly @@ -300,7 +300,7 @@ AssertionResult FakeHttpConnection::waitForNewStream(Event::Dispatcher& client_d Thread::LockGuard lock(lock_); while (new_streams_.empty()) { if (std::chrono::steady_clock::now() >= end_time) { - return AssertionResult("Timed out waiting for new stream."); + return AssertionFailure() << "Timed out waiting for new stream."; } Thread::CondVar::WaitStatus status = connection_event_.waitFor(lock_, 5ms); // As with waitForDisconnect, by default, waitForNewStream returns after the next event. diff --git a/test/integration/http_integration.cc b/test/integration/http_integration.cc index 293d761d93947..36bf5c9ad2ccb 100644 --- a/test/integration/http_integration.cc +++ b/test/integration/http_integration.cc @@ -584,6 +584,13 @@ void HttpIntegrationTest::testRouterDownstreamDisconnectBeforeRequestComplete( void HttpIntegrationTest::testRouterDownstreamDisconnectBeforeResponseComplete( ConnectionCreationFunction* create_connection) { +#ifdef __APPLE__ + // Skip this test on OS X: we can't detect the early close on OS X, and we + // won't clean up the upstream connection until it times out. See #4294. + if (downstream_protocol_ == Http::CodecClient::Type::HTTP1) { + return; + } +#endif initialize(); codec_client_ = makeHttpConnection( create_connection ? ((*create_connection)()) : makeClientConnection((lookupPort("http")))); diff --git a/test/integration/ssl_integration_test.cc b/test/integration/ssl_integration_test.cc index f72f0e7c518c0..f79f6f781c769 100644 --- a/test/integration/ssl_integration_test.cc +++ b/test/integration/ssl_integration_test.cc @@ -154,6 +154,13 @@ TEST_P(SslIntegrationTest, RouterDownstreamDisconnectBeforeRequestComplete) { } TEST_P(SslIntegrationTest, RouterDownstreamDisconnectBeforeResponseComplete) { +#ifdef __APPLE__ + // Skip this test on OS X: we can't detect the early close on OS X, and we + // won't clean up the upstream connection until it times out. See #4294. + if (downstream_protocol_ == Http::CodecClient::Type::HTTP1) { + return; + } +#endif ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { return makeSslClientConnection(false, false); }; From ae6a252221dcb6c7715068e5890f95bc00437fb4 Mon Sep 17 00:00:00 2001 From: Tal Nordan Date: Wed, 5 Sep 2018 11:57:32 -0700 Subject: [PATCH 03/15] router: fix matching when all domains have wildcards (#4326) When all domains of all virtual hosts had wildcard characters and there was a default virtual host, RouteMatcher::findVirtualHost() used to erroneously ignore the wildcard suffixes. This patch fixes the issue and introduces a unit test that covers this case. Risk level: Medium. The fix is straightforward, but users who rely on the erroneous behavior might be affected. Testing: Introduced TestRoutesWithWildcardAndDefaultOnly Signed-off-by: Tal Nordan tal.nordan@solo.io --- source/common/router/config_impl.cc | 2 +- test/common/router/config_impl_test.cc | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/source/common/router/config_impl.cc b/source/common/router/config_impl.cc index d5301073fd1ac..21b891ae79b36 100644 --- a/source/common/router/config_impl.cc +++ b/source/common/router/config_impl.cc @@ -851,7 +851,7 @@ RouteConstSharedPtr VirtualHostImpl::getRouteFromEntries(const Http::HeaderMap& const VirtualHostImpl* RouteMatcher::findVirtualHost(const Http::HeaderMap& headers) const { // Fast path the case where we only have a default virtual host. - if (virtual_hosts_.empty() && default_virtual_host_) { + if (virtual_hosts_.empty() && wildcard_virtual_host_suffixes_.empty() && default_virtual_host_) { return default_virtual_host_.get(); } diff --git a/test/common/router/config_impl_test.cc b/test/common/router/config_impl_test.cc index 9951064466b8a..2b1e7fe587a1c 100644 --- a/test/common/router/config_impl_test.cc +++ b/test/common/router/config_impl_test.cc @@ -535,6 +535,31 @@ TEST(RouteMatcherTest, TestRoutes) { } } +TEST(RouteMatcherTest, TestRoutesWithWildcardAndDefaultOnly) { + std::string yaml = R"EOF( +virtual_hosts: + - name: wildcard + domains: ["*.solo.io"] + routes: + - match: { prefix: "/" } + route: { cluster: "wildcard" } + - name: default + domains: ["*"] + routes: + - match: { prefix: "/" } + route: { cluster: "default" } + )EOF"; + + const auto proto_config = parseRouteConfigurationFromV2Yaml(yaml); + NiceMock factory_context; + TestConfigImpl config(proto_config, factory_context, true); + + EXPECT_EQ("wildcard", + config.route(genHeaders("gloo.solo.io", "/", "GET"), 0)->routeEntry()->clusterName()); + EXPECT_EQ("default", + config.route(genHeaders("example.com", "/", "GET"), 0)->routeEntry()->clusterName()); +} + TEST(RouteMatcherTest, TestRoutesWithInvalidRegex) { std::string invalid_route = R"EOF( virtual_hosts: From e34dcd62a2788e48956f8f9320565b693f4afbbc Mon Sep 17 00:00:00 2001 From: Greg Greenway Date: Wed, 5 Sep 2018 12:27:51 -0700 Subject: [PATCH 04/15] Fix crash in tcp_proxy (#4323) * Fix crash in tcp_proxy. Closing the upstream connection is not safe from the Filter destructor, because it triggers events back into the downstream connection, which is partially destructed. Ensure that the upstream connection is closed before the destructor is called. Fixes #4310 Signed-off-by: Greg Greenway --- source/common/tcp_proxy/tcp_proxy.cc | 31 ++++++++++++------- test/common/http/conn_manager_impl_test.cc | 15 +++++++++ .../network/filter_manager_impl_test.cc | 2 ++ test/common/tcp_proxy/tcp_proxy_test.cc | 31 ++++++++++++++++++- test/integration/fake_upstream.cc | 2 +- 5 files changed, 67 insertions(+), 14 deletions(-) diff --git a/source/common/tcp_proxy/tcp_proxy.cc b/source/common/tcp_proxy/tcp_proxy.cc index eb7d4196d9107..e4517f158111e 100644 --- a/source/common/tcp_proxy/tcp_proxy.cc +++ b/source/common/tcp_proxy/tcp_proxy.cc @@ -135,13 +135,8 @@ Filter::~Filter() { access_log->log(nullptr, nullptr, nullptr, getRequestInfo()); } - if (upstream_handle_) { - upstream_handle_->cancel(); - } - - if (upstream_conn_data_) { - upstream_conn_data_->connection().close(Network::ConnectionCloseType::NoFlush); - } + ASSERT(upstream_handle_ == nullptr); + ASSERT(upstream_conn_data_ == nullptr); } TcpProxyStats Config::SharedConfig::generateStats(Stats::Scope& scope) { @@ -412,17 +407,29 @@ void Filter::onDownstreamEvent(Network::ConnectionEvent event) { if (event == Network::ConnectionEvent::RemoteClose) { upstream_conn_data_->connection().close(Network::ConnectionCloseType::FlushWrite); - if (upstream_conn_data_ != nullptr && - upstream_conn_data_->connection().state() != Network::Connection::State::Closed) { - config_->drainManager().add(config_->sharedConfig(), std::move(upstream_conn_data_), - std::move(upstream_callbacks_), std::move(idle_timer_), - read_callbacks_->upstreamHost()); + // Events raised from the previous line may cause upstream_conn_data_ to be NULL if + // it was able to immediately flush all data. + + if (upstream_conn_data_ != nullptr) { + if (upstream_conn_data_->connection().state() != Network::Connection::State::Closed) { + config_->drainManager().add(config_->sharedConfig(), std::move(upstream_conn_data_), + std::move(upstream_callbacks_), std::move(idle_timer_), + read_callbacks_->upstreamHost()); + } else { + upstream_conn_data_.reset(); + } } } else if (event == Network::ConnectionEvent::LocalClose) { upstream_conn_data_->connection().close(Network::ConnectionCloseType::NoFlush); upstream_conn_data_.reset(); disableIdleTimer(); } + } else if (upstream_handle_) { + if (event == Network::ConnectionEvent::LocalClose || + event == Network::ConnectionEvent::RemoteClose) { + upstream_handle_->cancel(); + upstream_handle_ = nullptr; + } } } diff --git a/test/common/http/conn_manager_impl_test.cc b/test/common/http/conn_manager_impl_test.cc index 90e6e21f45a30..266a705a5fb97 100644 --- a/test/common/http/conn_manager_impl_test.cc +++ b/test/common/http/conn_manager_impl_test.cc @@ -1705,6 +1705,10 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketPrefixAndAutoHostRewrite) { Buffer::OwnedImpl fake_input("1234"); conn_manager_->onData(fake_input, false); + Tcp::ConnectionPool::UpstreamCallbacks* upstream_callbacks = nullptr; + EXPECT_CALL(*conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce( + Invoke([&](Tcp::ConnectionPool::UpstreamCallbacks& cb) { upstream_callbacks = &cb; })); conn_pool_.host_->hostname_ = "newhost"; conn_pool_.poolReady(upstream_conn_); @@ -1714,6 +1718,7 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketPrefixAndAutoHostRewrite) { EXPECT_EQ(1U, stats_.named_.downstream_cx_websocket_total_.value()); EXPECT_EQ(0U, stats_.named_.downstream_cx_http1_active_.value()); + upstream_callbacks->onEvent(Network::ConnectionEvent::RemoteClose); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); conn_manager_.reset(); EXPECT_EQ(0U, stats_.named_.downstream_cx_websocket_active_.value()); @@ -1753,8 +1758,13 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyData) { EXPECT_CALL(upstream_conn_, write(_, false)); EXPECT_CALL(upstream_conn_, write(BufferEqual(&early_data), false)); EXPECT_CALL(filter_callbacks_.connection_, readDisable(false)); + Tcp::ConnectionPool::UpstreamCallbacks* upstream_callbacks = nullptr; + EXPECT_CALL(*conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce( + Invoke([&](Tcp::ConnectionPool::UpstreamCallbacks& cb) { upstream_callbacks = &cb; })); conn_pool_.poolReady(upstream_conn_); + upstream_callbacks->onEvent(Network::ConnectionEvent::RemoteClose); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); conn_manager_.reset(); } @@ -1828,7 +1838,12 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyEndStream) { EXPECT_CALL(upstream_conn_, write(_, false)); EXPECT_CALL(upstream_conn_, write(_, true)).Times(0); + Tcp::ConnectionPool::UpstreamCallbacks* upstream_callbacks = nullptr; + EXPECT_CALL(*conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce( + Invoke([&](Tcp::ConnectionPool::UpstreamCallbacks& cb) { upstream_callbacks = &cb; })); conn_pool_.poolReady(upstream_conn_); + upstream_callbacks->onEvent(Network::ConnectionEvent::RemoteClose); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); conn_manager_.reset(); } diff --git a/test/common/network/filter_manager_impl_test.cc b/test/common/network/filter_manager_impl_test.cc index 06837dfbd450c..8aeaf11a978b9 100644 --- a/test/common/network/filter_manager_impl_test.cc +++ b/test/common/network/filter_manager_impl_test.cc @@ -214,6 +214,8 @@ TEST_F(NetworkFilterManagerTest, RateLimitAndTcpProxy) { EXPECT_CALL(upstream_connection, write(BufferEqual(&buffer), _)); read_buffer_.add("hello"); manager.onRead(); + + connection.raiseEvent(ConnectionEvent::RemoteClose); } } // namespace Network diff --git a/test/common/tcp_proxy/tcp_proxy_test.cc b/test/common/tcp_proxy/tcp_proxy_test.cc index f447f62b08e3e..1df01629ffa81 100644 --- a/test/common/tcp_proxy/tcp_proxy_test.cc +++ b/test/common/tcp_proxy/tcp_proxy_test.cc @@ -348,6 +348,12 @@ class TcpProxyTest : public testing::Test { .WillByDefault(SaveArg<0>(&access_log_data_)); } + ~TcpProxyTest() { + if (filter_ != nullptr) { + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); + } + } + void configure(const envoy::config::filter::network::tcp_proxy::v2::TcpProxy& config) { config_.reset(new Config(config, factory_context_)); } @@ -734,6 +740,22 @@ TEST_F(TcpProxyTest, DisconnectBeforeData) { filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); } +// Test that if the downstream connection is closed before the upstream connection +// is established, the upstream connection is cancelled. +TEST_F(TcpProxyTest, RemoteClosetBeforeUpstreamConnected) { + setup(1); + EXPECT_CALL(*conn_pool_handles_.at(0), cancel()); + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); +} + +// Test that if the downstream connection is closed before the upstream connection +// is established, the upstream connection is cancelled. +TEST_F(TcpProxyTest, LocalClosetBeforeUpstreamConnected) { + setup(1); + EXPECT_CALL(*conn_pool_handles_.at(0), cancel()); + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::LocalClose); +} + TEST_F(TcpProxyTest, UpstreamConnectFailure) { setup(1, accessLogConfig("%RESPONSE_FLAGS%")); @@ -873,6 +895,7 @@ TEST_F(TcpProxyTest, IdleTimeoutWithOutstandingDataFlushed) { TEST_F(TcpProxyTest, AccessLogUpstreamHost) { setup(1, accessLogConfig("%UPSTREAM_HOST% %UPSTREAM_CLUSTER%")); raiseEventUpstreamConnected(0); + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); filter_.reset(); EXPECT_EQ(access_log_data_, "127.0.0.1:80 fake_cluster"); } @@ -881,6 +904,7 @@ TEST_F(TcpProxyTest, AccessLogUpstreamHost) { TEST_F(TcpProxyTest, AccessLogUpstreamLocalAddress) { setup(1, accessLogConfig("%UPSTREAM_LOCAL_ADDRESS%")); raiseEventUpstreamConnected(0); + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); filter_.reset(); EXPECT_EQ(access_log_data_, "2.2.2.2:50000"); } @@ -893,6 +917,7 @@ TEST_F(TcpProxyTest, AccessLogDownstreamAddress) { filter_callbacks_.connection_.remote_address_ = Network::Utility::resolveUrl("tcp://1.1.1.1:40000"); setup(1, accessLogConfig("%DOWNSTREAM_REMOTE_ADDRESS_WITHOUT_PORT% %DOWNSTREAM_LOCAL_ADDRESS%")); + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); filter_.reset(); EXPECT_EQ(access_log_data_, "1.1.1.1 1.1.1.2:20000"); } @@ -1075,6 +1100,9 @@ TEST_F(TcpProxyRoutingTest, NonRoutableConnection) { EXPECT_EQ(total_cx + 1, config_->stats().downstream_cx_total_.value()); EXPECT_EQ(non_routable_cx + 1, config_->stats().downstream_cx_no_route_.value()); + + // Cleanup + filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); } TEST_F(TcpProxyRoutingTest, RoutableConnection) { @@ -1087,7 +1115,8 @@ TEST_F(TcpProxyRoutingTest, RoutableConnection) { connection_.local_address_ = std::make_shared("1.2.3.4", 9999); // Expect filter to try to open a connection to specified cluster. - EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)); + EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillOnce(Return(nullptr)); filter_->onNewConnection(); diff --git a/test/integration/fake_upstream.cc b/test/integration/fake_upstream.cc index 42ec7c2c4e33b..571184ffa75bf 100644 --- a/test/integration/fake_upstream.cc +++ b/test/integration/fake_upstream.cc @@ -536,7 +536,7 @@ AssertionResult FakeRawConnection::write(const std::string& data, bool end_strea Network::FilterStatus FakeRawConnection::ReadFilter::onData(Buffer::Instance& data, bool end_stream) { Thread::LockGuard lock(parent_.lock_); - ENVOY_LOG(debug, "got {} bytes", data.length()); + ENVOY_LOG(debug, "got {} bytes, end_stream {}", data.length(), end_stream); parent_.data_.append(data.toString()); parent_.half_closed_ = end_stream; data.drain(data.length()); From f936fc60f488cfae07f5e5d20d7381f0f23482fe Mon Sep 17 00:00:00 2001 From: Alex Konradi Date: Wed, 5 Sep 2018 15:54:10 -0400 Subject: [PATCH 05/15] ssl: serialize accesses to SSL socket factory contexts (#4345) Description: The ssl_ctx_ fields of the ServerSslSocketFactory and ClientSslSocketFactory are accessed and mutated from different threads without external serialization. This is a bug since instances of std::shared_ptr are not thread-safe (though different instances pointing to the same object are). This patch fixes the bug by using a mutex to serialize accesses. Risk Level: Low Testing: ran test suite Docs Changes: n/a Release Notes: n/a --- source/common/ssl/BUILD | 6 +++++- source/common/ssl/ssl_socket.cc | 22 ++++++++++++++++++---- source/common/ssl/ssl_socket.h | 7 +++++-- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/source/common/ssl/BUILD b/source/common/ssl/BUILD index f2e17e9726c54..9cc5179d0e998 100644 --- a/source/common/ssl/BUILD +++ b/source/common/ssl/BUILD @@ -12,7 +12,10 @@ envoy_cc_library( name = "ssl_socket_lib", srcs = ["ssl_socket.cc"], hdrs = ["ssl_socket.h"], - external_deps = ["ssl"], + external_deps = [ + "abseil_synchronization", + "ssl", + ], deps = [ ":context_config_lib", ":context_lib", @@ -23,6 +26,7 @@ envoy_cc_library( "//source/common/common:assert_lib", "//source/common/common:empty_string", "//source/common/common:minimal_logger_lib", + "//source/common/common:thread_annotations", "//source/common/http:headers_lib", ], ) diff --git a/source/common/ssl/ssl_socket.cc b/source/common/ssl/ssl_socket.cc index 5ba3ac9eb7dde..757c0d9a87fe3 100644 --- a/source/common/ssl/ssl_socket.cc +++ b/source/common/ssl/ssl_socket.cc @@ -422,7 +422,11 @@ Network::TransportSocketPtr ClientSslSocketFactory::createTransportSocket() cons // onAddOrUpdateSecret() could be invoked in the middle of checking the existence of ssl_ctx and // creating SslSocket using ssl_ctx. Capture ssl_ctx_ into a local variable so that we check and // use the same ssl_ctx to create SslSocket. - auto ssl_ctx = ssl_ctx_; + ClientContextSharedPtr ssl_ctx; + { + absl::ReaderMutexLock l(&ssl_ctx_mu_); + ssl_ctx = ssl_ctx_; + } if (ssl_ctx) { return std::make_unique(std::move(ssl_ctx), Ssl::InitialState::Client); } else { @@ -436,7 +440,10 @@ bool ClientSslSocketFactory::implementsSecureTransport() const { return true; } void ClientSslSocketFactory::onAddOrUpdateSecret() { ENVOY_LOG(debug, "Secret is updated."); - ssl_ctx_ = manager_.createSslClientContext(stats_scope_, *config_); + { + absl::WriterMutexLock l(&ssl_ctx_mu_); + ssl_ctx_ = manager_.createSslClientContext(stats_scope_, *config_); + } stats_.ssl_context_update_by_sds_.inc(); } @@ -454,7 +461,11 @@ Network::TransportSocketPtr ServerSslSocketFactory::createTransportSocket() cons // onAddOrUpdateSecret() could be invoked in the middle of checking the existence of ssl_ctx and // creating SslSocket using ssl_ctx. Capture ssl_ctx_ into a local variable so that we check and // use the same ssl_ctx to create SslSocket. - auto ssl_ctx = ssl_ctx_; + ServerContextSharedPtr ssl_ctx; + { + absl::ReaderMutexLock l(&ssl_ctx_mu_); + ssl_ctx = ssl_ctx_; + } if (ssl_ctx) { return std::make_unique(std::move(ssl_ctx), Ssl::InitialState::Server); } else { @@ -468,7 +479,10 @@ bool ServerSslSocketFactory::implementsSecureTransport() const { return true; } void ServerSslSocketFactory::onAddOrUpdateSecret() { ENVOY_LOG(debug, "Secret is updated."); - ssl_ctx_ = manager_.createSslServerContext(stats_scope_, *config_, server_names_); + { + absl::WriterMutexLock l(&ssl_ctx_mu_); + ssl_ctx_ = manager_.createSslServerContext(stats_scope_, *config_, server_names_); + } stats_.ssl_context_update_by_sds_.inc(); } diff --git a/source/common/ssl/ssl_socket.h b/source/common/ssl/ssl_socket.h index 951af874a655f..1cdfce9d1711b 100644 --- a/source/common/ssl/ssl_socket.h +++ b/source/common/ssl/ssl_socket.h @@ -12,6 +12,7 @@ #include "common/common/logger.h" #include "common/ssl/context_impl.h" +#include "absl/synchronization/mutex.h" #include "openssl/ssl.h" namespace Envoy { @@ -101,7 +102,8 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory, Stats::Scope& stats_scope_; SslSocketFactoryStats stats_; ClientContextConfigPtr config_; - ClientContextSharedPtr ssl_ctx_; + mutable absl::Mutex ssl_ctx_mu_; + ClientContextSharedPtr ssl_ctx_ GUARDED_BY(ssl_ctx_mu_); }; class ServerSslSocketFactory : public Network::TransportSocketFactory, @@ -123,7 +125,8 @@ class ServerSslSocketFactory : public Network::TransportSocketFactory, SslSocketFactoryStats stats_; ServerContextConfigPtr config_; const std::vector server_names_; - ServerContextSharedPtr ssl_ctx_; + mutable absl::Mutex ssl_ctx_mu_; + ServerContextSharedPtr ssl_ctx_ GUARDED_BY(ssl_ctx_mu_); }; } // namespace Ssl From 763f2a740e10ac705e98d01567f0506104fb8135 Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Wed, 5 Sep 2018 14:10:04 -0700 Subject: [PATCH 06/15] thrift: refactor Thrift router to allow protocol upgrade (#4286) Modifies the Thrift router to allow protocols to upgrade on initial data from the downstream connection and upgrade an upstream connection before transmitting a request. As part of this work, Transport and Protocol objects are reused across downstream connections and for an upstream's request and response. *Risk Level*: low, upgrade path unused *Testing*: unit tests *Docs Changes*: n/a *Release Notes*: n/a Signed-off-by: Stephan Zuercher --- .../filters/network/thrift_proxy/BUILD | 34 ++ .../filters/network/thrift_proxy/config.cc | 4 - .../filters/network/thrift_proxy/config.h | 6 +- .../network/thrift_proxy/conn_manager.cc | 43 +- .../network/thrift_proxy/conn_manager.h | 22 +- .../filters/network/thrift_proxy/conn_state.h | 48 ++ .../filters/network/thrift_proxy/decoder.cc | 35 +- .../filters/network/thrift_proxy/decoder.h | 19 +- .../network/thrift_proxy/filters/filter.h | 6 +- .../filters/network/thrift_proxy/metadata.h | 9 + .../filters/network/thrift_proxy/protocol.h | 102 +++- .../network/thrift_proxy/protocol_converter.h | 6 +- .../filters/network/thrift_proxy/router/BUILD | 9 +- .../network/thrift_proxy/router/router.h | 2 + .../thrift_proxy/router/router_impl.cc | 77 ++- .../network/thrift_proxy/router/router_impl.h | 7 +- .../network/thrift_proxy/thrift_object.h | 247 +++++++++ .../thrift_proxy/thrift_object_impl.cc | 394 ++++++++++++++ .../network/thrift_proxy/thrift_object_impl.h | 262 ++++++++++ .../filters/network/thrift_proxy/BUILD | 13 + .../network/thrift_proxy/conn_manager_test.cc | 123 ++++- .../network/thrift_proxy/decoder_test.cc | 167 +++--- .../filters/network/thrift_proxy/mocks.cc | 4 + .../filters/network/thrift_proxy/mocks.h | 18 +- .../network/thrift_proxy/router_test.cc | 144 ++++- .../thrift_proxy/thrift_object_impl_test.cc | 494 ++++++++++++++++++ 26 files changed, 2070 insertions(+), 225 deletions(-) create mode 100644 source/extensions/filters/network/thrift_proxy/conn_state.h create mode 100644 source/extensions/filters/network/thrift_proxy/thrift_object.h create mode 100644 source/extensions/filters/network/thrift_proxy/thrift_object_impl.cc create mode 100644 source/extensions/filters/network/thrift_proxy/thrift_object_impl.h create mode 100644 test/extensions/filters/network/thrift_proxy/thrift_object_impl_test.cc diff --git a/source/extensions/filters/network/thrift_proxy/BUILD b/source/extensions/filters/network/thrift_proxy/BUILD index 25c6dfda0dda7..e2d1789976b63 100644 --- a/source/extensions/filters/network/thrift_proxy/BUILD +++ b/source/extensions/filters/network/thrift_proxy/BUILD @@ -105,6 +105,7 @@ envoy_cc_library( external_deps = ["abseil_optional"], deps = [ ":thrift_lib", + "//include/envoy/buffer:buffer_interface", "//source/common/common:macros", ], ) @@ -128,14 +129,18 @@ envoy_cc_library( ], external_deps = ["abseil_optional"], deps = [ + ":conn_state_lib", ":decoder_events_lib", ":metadata_lib", ":thrift_lib", + ":thrift_object_interface", + ":transport_interface", "//include/envoy/buffer:buffer_interface", "//include/envoy/registry", "//source/common/common:assert_lib", "//source/common/config:utility_lib", "//source/common/singleton:const_singleton", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", ], ) @@ -221,6 +226,14 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "conn_state_lib", + hdrs = ["conn_state.h"], + deps = [ + "//include/envoy/tcp:conn_pool_interface", + ], +) + envoy_cc_library( name = "thrift_lib", hdrs = ["thrift.h"], @@ -230,6 +243,27 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "thrift_object_interface", + hdrs = ["thrift_object.h"], + deps = [ + "//include/envoy/buffer:buffer_interface", + ], +) + +envoy_cc_library( + name = "thrift_object_lib", + srcs = ["thrift_object_impl.cc"], + hdrs = ["thrift_object_impl.h"], + deps = [ + ":decoder_lib", + ":thrift_lib", + ":thrift_object_interface", + ":unframed_transport_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + ], +) + envoy_cc_library( name = "auto_transport_lib", srcs = [ diff --git a/source/extensions/filters/network/thrift_proxy/config.cc b/source/extensions/filters/network/thrift_proxy/config.cc index bba8a035a25bf..359cf7f77ed5c 100644 --- a/source/extensions/filters/network/thrift_proxy/config.cc +++ b/source/extensions/filters/network/thrift_proxy/config.cc @@ -142,10 +142,6 @@ void ConfigImpl::createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& c } } -DecoderPtr ConfigImpl::createDecoder(DecoderCallbacks& callbacks) { - return std::make_unique(createTransport(), createProtocol(), callbacks); -} - TransportPtr ConfigImpl::createTransport() { return NamedTransportConfigFactory::getFactory(transport_).createTransport(); } diff --git a/source/extensions/filters/network/thrift_proxy/config.h b/source/extensions/filters/network/thrift_proxy/config.h index 0b32e7964c4ea..71c2d1c580c91 100644 --- a/source/extensions/filters/network/thrift_proxy/config.h +++ b/source/extensions/filters/network/thrift_proxy/config.h @@ -76,13 +76,11 @@ class ConfigImpl : public Config, // Config ThriftFilterStats& stats() override { return stats_; } ThriftFilters::FilterChainFactory& filterFactory() override { return *this; } - DecoderPtr createDecoder(DecoderCallbacks& callbacks) override; + TransportPtr createTransport() override; + ProtocolPtr createProtocol() override; Router::Config& routerConfig() override { return *this; } private: - TransportPtr createTransport(); - ProtocolPtr createProtocol(); - Server::Configuration::FactoryContext& context_; const std::string stats_prefix_; ThriftFilterStats stats_; diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc index 6b83435351b9d..de34169fc385f 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.cc +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -16,7 +16,9 @@ namespace NetworkFilters { namespace ThriftProxy { ConnectionManager::ConnectionManager(Config& config) - : config_(config), stats_(config_.stats()), decoder_(config_.createDecoder(*this)) {} + : config_(config), stats_(config_.stats()), transport_(config.createTransport()), + protocol_(config.createProtocol()), + decoder_(std::make_unique(*transport_, *protocol_, *this)) {} ConnectionManager::~ConnectionManager() {} @@ -67,22 +69,14 @@ void ConnectionManager::dispatch() { } void ConnectionManager::sendLocalReply(MessageMetadata& metadata, const DirectResponse& response) { - // Use the factory to get the concrete protocol from the decoder protocol (as opposed to - // potentially pre-detection auto protocol). - ProtocolType proto_type = decoder_->protocolType(); - ProtocolPtr proto = NamedProtocolConfigFactory::getFactory(proto_type).createProtocol(); Buffer::OwnedImpl buffer; - response.encode(metadata, *proto, buffer); - - // Same logic as protocol above. - TransportPtr transport = - NamedTransportConfigFactory::getFactory(decoder_->transportType()).createTransport(); + response.encode(metadata, *protocol_, buffer); Buffer::OwnedImpl response_buffer; - metadata.setProtocol(proto_type); - transport->encodeFrame(response_buffer, metadata, buffer); + metadata.setProtocol(protocol_->type()); + transport_->encodeFrame(response_buffer, metadata, buffer); read_callbacks_->connection().write(response_buffer, false); } @@ -230,12 +224,30 @@ FilterStatus ConnectionManager::ActiveRpc::transportEnd() { break; } - return decoder_filter_->transportEnd(); + FilterStatus status = event_handler_->transportEnd(); + + if (metadata_->isProtocolUpgradeMessage()) { + ENVOY_CONN_LOG(error, "thrift: sending protocol upgrade response", + parent_.read_callbacks_->connection()); + sendLocalReply(*parent_.protocol_->upgradeResponse(*upgrade_handler_)); + } + + return status; } FilterStatus ConnectionManager::ActiveRpc::messageBegin(MessageMetadataSharedPtr metadata) { metadata_ = metadata; + if (metadata_->isProtocolUpgradeMessage()) { + ASSERT(parent_.protocol_->supportsUpgrade()); + + ENVOY_CONN_LOG(error, "thrift: decoding protocol upgrade request", + parent_.read_callbacks_->connection()); + upgrade_handler_ = parent_.protocol_->upgradeRequestDecoder(); + ASSERT(upgrade_handler_ != nullptr); + event_handler_ = upgrade_handler_.get(); + } + return event_handler_->messageBegin(metadata); } @@ -282,11 +294,10 @@ void ConnectionManager::ActiveRpc::sendLocalReply(const DirectResponse& response parent_.doDeferredRpcDestroy(*this); } -void ConnectionManager::ActiveRpc::startUpstreamResponse(TransportType transport_type, - ProtocolType protocol_type) { +void ConnectionManager::ActiveRpc::startUpstreamResponse(Transport& transport, Protocol& protocol) { ASSERT(response_decoder_ == nullptr); - response_decoder_ = std::make_unique(*this, transport_type, protocol_type); + response_decoder_ = std::make_unique(*this, transport, protocol); } bool ConnectionManager::ActiveRpc::upstreamData(Buffer::Instance& buffer) { diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h index 5d4f4f9cf1151..6432a51377960 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.h +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -31,7 +31,8 @@ class Config { virtual ThriftFilters::FilterChainFactory& filterFactory() PURE; virtual ThriftFilterStats& stats() PURE; - virtual DecoderPtr createDecoder(DecoderCallbacks& callbacks) PURE; + virtual TransportPtr createTransport() PURE; + virtual ProtocolPtr createProtocol() PURE; virtual Router::Config& routerConfig() PURE; }; @@ -74,18 +75,10 @@ class ConnectionManager : public Network::ReadFilter, struct ActiveRpc; struct ResponseDecoder : public DecoderCallbacks, public ProtocolConverter { - ResponseDecoder(ActiveRpc& parent, TransportType transport_type, ProtocolType protocol_type) - : parent_(parent), - decoder_(std::make_unique( - NamedTransportConfigFactory::getFactory(transport_type).createTransport(), - NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol(), *this)), + ResponseDecoder(ActiveRpc& parent, Transport& transport, Protocol& protocol) + : parent_(parent), decoder_(std::make_unique(transport, protocol, *this)), complete_(false), first_reply_field_(false) { - // Use the factory to get the concrete protocol from the decoder protocol (as opposed to - // potentially pre-detection auto protocol). - initProtocolConverter( - NamedProtocolConfigFactory::getFactory(parent_.parent_.decoder_->protocolType()) - .createProtocol(), - parent_.response_buffer_); + initProtocolConverter(*parent_.parent_.protocol_, parent_.response_buffer_); } bool onData(Buffer::Instance& data); @@ -149,7 +142,7 @@ class ConnectionManager : public Network::ReadFilter, return parent_.decoder_->protocolType(); } void sendLocalReply(const DirectResponse& response) override; - void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) override; + void startUpstreamResponse(Transport& transport, Protocol& protocol) override; bool upstreamData(Buffer::Instance& buffer) override; void resetDownstreamConnection() override; @@ -170,6 +163,7 @@ class ConnectionManager : public Network::ReadFilter, uint64_t stream_id_; MessageMetadataSharedPtr metadata_; ThriftFilters::DecoderFilterSharedPtr decoder_filter_; + DecoderEventHandlerSharedPtr upgrade_handler_; ResponseDecoderPtr response_decoder_; absl::optional cached_route_; Buffer::OwnedImpl response_buffer_; @@ -188,6 +182,8 @@ class ConnectionManager : public Network::ReadFilter, Network::ReadFilterCallbacks* read_callbacks_{}; + TransportPtr transport_; + ProtocolPtr protocol_; DecoderPtr decoder_; std::list rpcs_; Buffer::OwnedImpl request_buffer_; diff --git a/source/extensions/filters/network/thrift_proxy/conn_state.h b/source/extensions/filters/network/thrift_proxy/conn_state.h new file mode 100644 index 0000000000000..f5db30d458461 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/conn_state.h @@ -0,0 +1,48 @@ +#pragma once + +#include "envoy/tcp/conn_pool.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * ThriftConnectionState tracks thrift-related connection state for pooled connections. + */ +class ThriftConnectionState : public Tcp::ConnectionPool::ConnectionState { +public: + /** + * @return true if this upgrade has been attempted on this connection. + */ + bool upgradeAttempted() const { return upgrade_attempted_; } + /** + * @return true if this connection has been upgraded + */ + bool isUpgraded() const { return upgraded_; } + + /** + * Marks the connection as successfully upgraded. + */ + void markUpgraded() { + upgrade_attempted_ = true; + upgraded_ = true; + } + + /** + * Marks the connection as not upgraded. + */ + void markUpgradeFailed() { + upgrade_attempted_ = true; + upgraded_ = false; + } + +private: + bool upgrade_attempted_{false}; + bool upgraded_{false}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index c6d4b470ea890..bc56bffea0b00 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -334,6 +334,9 @@ ProtocolState DecoderStateMachine::popReturnState() { ProtocolState DecoderStateMachine::run(Buffer::Instance& buffer) { while (state_ != ProtocolState::Done) { + ENVOY_LOG(trace, "thrift: state {}, {} bytes available", ProtocolStateNameValues::name(state_), + buffer.length()); + DecoderStatus s = handleState(buffer); if (s.next_state_ == ProtocolState::WaitForData) { return ProtocolState::WaitForData; @@ -350,8 +353,8 @@ ProtocolState DecoderStateMachine::run(Buffer::Instance& buffer) { return state_; } -Decoder::Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks) - : transport_(std::move(transport)), protocol_(std::move(protocol)), callbacks_(callbacks) {} +Decoder::Decoder(Transport& transport, Protocol& protocol, DecoderCallbacks& callbacks) + : transport_(transport), protocol_(protocol), callbacks_(callbacks) {} void Decoder::complete() { request_.reset(); @@ -377,22 +380,22 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { metadata_ = std::make_shared(); } - if (!transport_->decodeFrameStart(data, *metadata_)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_->name()); + if (!transport_.decodeFrameStart(data, *metadata_)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_.name()); buffer_underflow = true; return FilterStatus::Continue; } - ENVOY_LOG(debug, "thrift: {} transport started", transport_->name()); + ENVOY_LOG(debug, "thrift: {} transport started", transport_.name()); if (metadata_->hasProtocol()) { - if (protocol_->type() == ProtocolType::Auto) { - protocol_->setType(metadata_->protocol()); - ENVOY_LOG(debug, "thrift: {} transport forced {} protocol", transport_->name(), - protocol_->name()); - } else if (metadata_->protocol() != protocol_->type()) { + if (protocol_.type() == ProtocolType::Auto) { + protocol_.setType(metadata_->protocol()); + ENVOY_LOG(debug, "thrift: {} transport forced {} protocol", transport_.name(), + protocol_.name()); + } else if (metadata_->protocol() != protocol_.type()) { throw EnvoyException(fmt::format("transport reports protocol {}, but configured for {}", ProtocolNames::get().fromType(metadata_->protocol()), - ProtocolNames::get().fromType(protocol_->type()))); + ProtocolNames::get().fromType(protocol_.type()))); } } if (metadata_->hasAppException()) { @@ -406,7 +409,7 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { request_ = std::make_unique(callbacks_.newDecoderEventHandler()); frame_started_ = true; state_machine_ = - std::make_unique(*protocol_, metadata_, request_->handler_); + std::make_unique(protocol_, metadata_, request_->handler_); if (request_->handler_.transportBegin(metadata_) == FilterStatus::StopIteration) { return FilterStatus::StopIteration; @@ -415,7 +418,7 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { ASSERT(state_machine_ != nullptr); - ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_->name(), + ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_.name(), ProtocolStateNameValues::name(state_machine_->currentState()), data.length()); ProtocolState rv = state_machine_->run(data); @@ -431,8 +434,8 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { ASSERT(rv == ProtocolState::Done); // Message complete, decode end of frame. - if (!transport_->decodeFrameEnd(data)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_->name()); + if (!transport_.decodeFrameEnd(data)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_.name()); buffer_underflow = true; return FilterStatus::Continue; } @@ -440,7 +443,7 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { frame_ended_ = true; metadata_.reset(); - ENVOY_LOG(debug, "thrift: {} transport ended", transport_->name()); + ENVOY_LOG(debug, "thrift: {} transport ended", transport_.name()); if (request_->handler_.transportEnd() == FilterStatus::StopIteration) { return FilterStatus::StopIteration; } diff --git a/source/extensions/filters/network/thrift_proxy/decoder.h b/source/extensions/filters/network/thrift_proxy/decoder.h index cda5d75b4a8b7..e2886aedebd60 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.h +++ b/source/extensions/filters/network/thrift_proxy/decoder.h @@ -2,7 +2,6 @@ #include "envoy/buffer/buffer.h" -#include "common/buffer/buffer_impl.h" #include "common/common/assert.h" #include "common/common/logger.h" @@ -61,7 +60,7 @@ class ProtocolStateNameValues { * DecoderStateMachine is the Thrift message state machine as described in * source/extensions/filters/network/thrift_proxy/docs. */ -class DecoderStateMachine { +class DecoderStateMachine : public Logger::Loggable { public: DecoderStateMachine(Protocol& proto, MessageMetadataSharedPtr& metadata, DecoderEventHandler& handler) @@ -183,15 +182,15 @@ class DecoderCallbacks { }; /** - * Decoder encapsulates a configured TransportPtr and ProtocolPtr. + * Decoder encapsulates a configured Transport and Protocol and provides the ability to decode + * Thrift messages. */ class Decoder : public Logger::Loggable { public: - Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks); - Decoder(TransportType transport_type, ProtocolType protocol_type, DecoderCallbacks& callbacks); + Decoder(Transport& transport, Protocol& protocol, DecoderCallbacks& callbacks); /** - * Drains data from the given buffer while executing a DecoderStateMachine over the data. + * Drains data from the given buffer while executing a state machine over the data. * * @param data a Buffer containing Thrift protocol data * @param buffer_underflow bool set to true if more data is required to continue decoding @@ -201,8 +200,8 @@ class Decoder : public Logger::Loggable { */ FilterStatus onData(Buffer::Instance& data, bool& buffer_underflow); - TransportType transportType() { return transport_->type(); } - ProtocolType protocolType() { return protocol_->type(); } + TransportType transportType() { return transport_.type(); } + ProtocolType protocolType() { return protocol_.type(); } private: struct ActiveRequest { @@ -214,8 +213,8 @@ class Decoder : public Logger::Loggable { void complete(); - TransportPtr transport_; - ProtocolPtr protocol_; + Transport& transport_; + Protocol& protocol_; DecoderCallbacks& callbacks_; ActiveRequestPtr request_; MessageMetadataSharedPtr metadata_; diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h index 304db348d9e29..4183514455b91 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/filter.h +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -68,10 +68,10 @@ class DecoderFilterCallbacks { /** * Indicates the start of an upstream response. May only be called once. - * @param transport_type TransportType the upstream is using - * @param protocol_type ProtocolType the upstream is using + * @param transport the transport used by the upstream response + * @param protocol the protocol used by the upstream response */ - virtual void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) PURE; + virtual void startUpstreamResponse(Transport& transport, Protocol& protocol) PURE; /** * Called with upstream response data. diff --git a/source/extensions/filters/network/thrift_proxy/metadata.h b/source/extensions/filters/network/thrift_proxy/metadata.h index e9498d88eef43..bb659e5afdc47 100644 --- a/source/extensions/filters/network/thrift_proxy/metadata.h +++ b/source/extensions/filters/network/thrift_proxy/metadata.h @@ -4,8 +4,11 @@ #include #include +#include #include +#include "envoy/buffer/buffer.h" + #include "common/common/macros.h" #include "common/http/header_map_impl.h" @@ -62,6 +65,11 @@ class MessageMetadata { AppExceptionType appExceptionType() const { return app_ex_type_.value(); } const std::string& appExceptionMessage() const { return app_ex_msg_.value(); } + bool isProtocolUpgradeMessage() const { return protocol_upgrade_message_; } + void setProtocolUpgradeMessage(bool upgrade_message) { + protocol_upgrade_message_ = upgrade_message; + } + private: absl::optional frame_size_{}; absl::optional proto_{}; @@ -71,6 +79,7 @@ class MessageMetadata { Http::HeaderMapImpl headers_; absl::optional app_ex_type_; absl::optional app_ex_msg_; + bool protocol_upgrade_message_{false}; }; typedef std::shared_ptr MessageMetadataSharedPtr; diff --git a/source/extensions/filters/network/thrift_proxy/protocol.h b/source/extensions/filters/network/thrift_proxy/protocol.h index 0e4066021ce9b..316ca0e79df3a 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol.h +++ b/source/extensions/filters/network/thrift_proxy/protocol.h @@ -11,8 +11,12 @@ #include "common/config/utility.h" #include "common/singleton/const_singleton.h" +#include "extensions/filters/network/thrift_proxy/conn_state.h" +#include "extensions/filters/network/thrift_proxy/decoder_events.h" #include "extensions/filters/network/thrift_proxy/metadata.h" #include "extensions/filters/network/thrift_proxy/thrift.h" +#include "extensions/filters/network/thrift_proxy/thrift_object.h" +#include "extensions/filters/network/thrift_proxy/transport.h" #include "absl/strings/string_view.h" @@ -21,6 +25,9 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +class DirectResponse; +typedef std::unique_ptr DirectResponsePtr; + /** * Protocol represents the operations necessary to implement the a generic Thrift protocol. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-protocol-spec.md @@ -394,10 +401,86 @@ class Protocol { * @param value std::string to write */ virtual void writeBinary(Buffer::Instance& buffer, const std::string& value) PURE; + + /** + * Indicates whether a protocol uses start-of-connection messages to negotiate protocol options. + * If this method returns true, the Protocol must invoke setProtocolUpgradeMessage during + * readMessageBegin if it detects an upgrade request. + * + * @return true for protocols that exchange messages at the start of a connection to negotiate + * protocol upgrade (or options) + */ + virtual bool supportsUpgrade() { return false; } + + /** + * Creates an opaque DecoderEventHandlerSharedPtr that can decode a downstream client's upgrade + * request. When the request is complete, the decoder is passed back to writeUpgradeResponse + * to allow the Protocol to update its internal state and generate a response to the request. + * + * @return a DecoderEventHandlerSharedPtr that decodes a downstream client's upgrade request + */ + virtual DecoderEventHandlerSharedPtr upgradeRequestDecoder() { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + + /** + * Writes a response to a downstream client's upgrade request. + * @param decoder DecoderEventHandlerSharedPtr created by upgradeRequestDecoder + * @return DirectResponsePtr containing an upgrade response + */ + virtual DirectResponsePtr upgradeResponse(const DecoderEventHandler& decoder) { + UNREFERENCED_PARAMETER(decoder); + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + + /** + * Checks whether a given upstream connection can be upgraded and generates an upgrade request + * message. If this method returns a ThriftObject it will be used to decode the upstream's next + * response. + * + * @param transport the Transport to use for decoding the response + * @param state ThriftConnectionState tracking whether upgrade has already been performed + * @param buffer Buffer::Instance to modify with an upgrade request + * @return a ThriftObject capable of decoding an upgrade response or nullptr if upgrade was + * already completed (successfully or not) + */ + virtual ThriftObjectPtr attemptUpgrade(Transport& transport, ThriftConnectionState& state, + Buffer::Instance& buffer) { + UNREFERENCED_PARAMETER(transport); + UNREFERENCED_PARAMETER(state); + UNREFERENCED_PARAMETER(buffer); + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + + /** + * Completes an upgrade previously started via attemptUpgrade. + * @param response ThriftObject created by attemptUpgrade, after the response has completed + * decoding + */ + virtual void completeUpgrade(ThriftConnectionState& state, ThriftObject& response) { + UNREFERENCED_PARAMETER(state); + UNREFERENCED_PARAMETER(response); + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } }; typedef std::unique_ptr ProtocolPtr; +/** + * A DirectResponse manipulates a Protocol to directly create a Thrift response message. + */ +class DirectResponse { +public: + virtual ~DirectResponse() {} + + /** + * Encodes the response via the given Protocol. + * @param metadata the MessageMetadata for the request that generated this response + * @param proto the Protocol to be used for message encoding + * @param buffer the Buffer into which the message should be encoded + */ + virtual void encode(MessageMetadata& metadata, Protocol& proto, + Buffer::Instance& buffer) const PURE; +}; + /** * Implemented by each Thrift protocol and registered via Registry::registerFactory or the * convenience class RegisterFactory. @@ -444,25 +527,6 @@ template class ProtocolFactoryBase : public NamedProtocolCo const std::string name_; }; -/** - * A DirectResponse manipulates a Protocol to directly create a Thrift response message. - */ -class DirectResponse { -public: - virtual ~DirectResponse() {} - - /** - * Encodes the response via the given Protocol. - * @param metadata the MessageMetadata for the request that generated this response - * @param proto the Protocol to be used for message encoding - * @param buffer the Buffer into which the message should be encoded - */ - virtual void encode(MessageMetadata& metadata, Protocol& proto, - Buffer::Instance& buffer) const PURE; -}; - -typedef std::unique_ptr DirectResponsePtr; - } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/protocol_converter.h b/source/extensions/filters/network/thrift_proxy/protocol_converter.h index 8a3fde7894505..22a868f1a40af 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_converter.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_converter.h @@ -19,8 +19,8 @@ class ProtocolConverter : public virtual DecoderEventHandler { ProtocolConverter() {} virtual ~ProtocolConverter() {} - void initProtocolConverter(ProtocolPtr&& proto, Buffer::Instance& buffer) { - proto_ = std::move(proto); + void initProtocolConverter(Protocol& proto, Buffer::Instance& buffer) { + proto_ = &proto; buffer_ = &buffer; } @@ -125,7 +125,7 @@ class ProtocolConverter : public virtual DecoderEventHandler { ProtocolType protocolType() const { return proto_->type(); } private: - ProtocolPtr proto_; + Protocol* proto_; Buffer::Instance* buffer_{}; }; diff --git a/source/extensions/filters/network/thrift_proxy/router/BUILD b/source/extensions/filters/network/thrift_proxy/router/BUILD index e5051335b04e9..ba6cf07cc045b 100644 --- a/source/extensions/filters/network/thrift_proxy/router/BUILD +++ b/source/extensions/filters/network/thrift_proxy/router/BUILD @@ -26,7 +26,9 @@ envoy_cc_library( name = "router_interface", hdrs = ["router.h"], external_deps = ["abseil_optional"], - deps = [], + deps = [ + "//source/extensions/filters/network/thrift_proxy:metadata_lib", + ], ) envoy_cc_library( @@ -46,8 +48,9 @@ envoy_cc_library( "//source/extensions/filters/network/thrift_proxy:app_exception_lib", "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", "//source/extensions/filters/network/thrift_proxy:protocol_converter_lib", - "//source/extensions/filters/network/thrift_proxy:protocol_lib", - "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_interface", + "//source/extensions/filters/network/thrift_proxy:thrift_object_interface", + "//source/extensions/filters/network/thrift_proxy:transport_interface", "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", "@envoy_api//envoy/config/filter/network/thrift_proxy/v2alpha1:thrift_proxy_cc", ], diff --git a/source/extensions/filters/network/thrift_proxy/router/router.h b/source/extensions/filters/network/thrift_proxy/router/router.h index 7f995e5e25a1d..b456a24ae7da1 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router.h +++ b/source/extensions/filters/network/thrift_proxy/router/router.h @@ -3,6 +3,8 @@ #include #include +#include "extensions/filters/network/thrift_proxy/metadata.h" + namespace Envoy { namespace Extensions { namespace NetworkFilters { diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc index 57e8be9f6caf1..a854b7432aa87 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -213,13 +213,12 @@ FilterStatus Router::messageBegin(MessageMetadataSharedPtr metadata) { FilterStatus Router::messageEnd() { ProtocolConverter::messageEnd(); - TransportPtr transport = - NamedTransportConfigFactory::getFactory(upstream_request_->transport_type_).createTransport(); Buffer::OwnedImpl transport_buffer; - upstream_request_->metadata_->setProtocol(upstream_request_->protocol_type_); + upstream_request_->metadata_->setProtocol(upstream_request_->protocol_->type()); - transport->encodeFrame(transport_buffer, *upstream_request_->metadata_, upstream_request_buffer_); + upstream_request_->transport_->encodeFrame(transport_buffer, *upstream_request_->metadata_, + upstream_request_buffer_); upstream_request_->conn_data_->connection().write(transport_buffer, false); upstream_request_->onRequestComplete(); return FilterStatus::Continue; @@ -228,16 +227,32 @@ FilterStatus Router::messageEnd() { void Router::onUpstreamData(Buffer::Instance& data, bool end_stream) { ASSERT(!upstream_request_->response_complete_); - if (!upstream_request_->response_started_) { - callbacks_->startUpstreamResponse(upstream_request_->transport_type_, - upstream_request_->protocol_type_); - upstream_request_->response_started_ = true; - } + if (upstream_request_->upgrade_response_ != nullptr) { + // Handle upgrade response. + if (!upstream_request_->upgrade_response_->onData(data)) { + // Wait for more data. + return; + } - if (callbacks_->upstreamData(data)) { - upstream_request_->onResponseComplete(); - cleanup(); - return; + upstream_request_->protocol_->completeUpgrade( + *upstream_request_->conn_data_->connectionStateTyped(), + *upstream_request_->upgrade_response_); + + upstream_request_->upgrade_response_.reset(); + upstream_request_->onRequestStart(true); + } else { + // Handle normal response. + if (!upstream_request_->response_started_) { + callbacks_->startUpstreamResponse(*upstream_request_->transport_, + *upstream_request_->protocol_); + upstream_request_->response_started_ = true; + } + + if (callbacks_->upstreamData(data)) { + upstream_request_->onResponseComplete(); + cleanup(); + return; + } } if (end_stream) { @@ -284,9 +299,10 @@ void Router::cleanup() { upstream_request_.reset(); } Router::UpstreamRequest::UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, MessageMetadataSharedPtr& metadata, TransportType transport_type, ProtocolType protocol_type) - : parent_(parent), conn_pool_(pool), metadata_(metadata), transport_type_(transport_type), - protocol_type_(protocol_type), request_complete_(false), response_started_(false), - response_complete_(false) {} + : parent_(parent), conn_pool_(pool), metadata_(metadata), + transport_(NamedTransportConfigFactory::getFactory(transport_type).createTransport()), + protocol_(NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol()), + request_complete_(false), response_started_(false), response_complete_(false) {} Router::UpstreamRequest::~UpstreamRequest() {} @@ -298,6 +314,11 @@ FilterStatus Router::UpstreamRequest::start() { return FilterStatus::StopIteration; } + if (upgrade_response_ != nullptr) { + // Pause while we wait for an upgrade response. + return FilterStatus::StopIteration; + } + return FilterStatus::Continue; } @@ -329,12 +350,28 @@ void Router::UpstreamRequest::onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr onUpstreamHostSelected(host); conn_data_ = std::move(conn_data); conn_data_->addUpstreamCallbacks(parent_); - conn_pool_handle_ = nullptr; - parent_.initProtocolConverter( - NamedProtocolConfigFactory::getFactory(protocol_type_).createProtocol(), - parent_.upstream_request_buffer_); + ThriftConnectionState* state = conn_data_->connectionStateTyped(); + if (state == nullptr) { + conn_data_->setConnectionState(std::make_unique()); + state = conn_data_->connectionStateTyped(); + } + + if (protocol_->supportsUpgrade()) { + upgrade_response_ = + protocol_->attemptUpgrade(*transport_, *state, parent_.upstream_request_buffer_); + if (upgrade_response_ != nullptr) { + conn_data_->connection().write(parent_.upstream_request_buffer_, false); + return; + } + } + + onRequestStart(continue_decoding); +} + +void Router::UpstreamRequest::onRequestStart(bool continue_decoding) { + parent_.initProtocolConverter(*protocol_, parent_.upstream_request_buffer_); // TODO(zuercher): need to use an upstream-connection-specific sequence id parent_.convertMessageBegin(metadata_); diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.h b/source/extensions/filters/network/thrift_proxy/router/router_impl.h index 224b6fb09997c..8d6033bdb410a 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.h +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.h @@ -15,6 +15,7 @@ #include "extensions/filters/network/thrift_proxy/conn_manager.h" #include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/router/router.h" +#include "extensions/filters/network/thrift_proxy/thrift_object.h" #include "absl/types/optional.h" @@ -135,6 +136,7 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, void onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn, Upstream::HostDescriptionConstSharedPtr host) override; + void onRequestStart(bool continue_decoding); void onRequestComplete(); void onResponseComplete(); void onUpstreamHostSelected(Upstream::HostDescriptionConstSharedPtr host); @@ -147,8 +149,9 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, Tcp::ConnectionPool::Cancellable* conn_pool_handle_{}; Tcp::ConnectionPool::ConnectionDataPtr conn_data_; Upstream::HostDescriptionConstSharedPtr upstream_host_; - TransportType transport_type_; - ProtocolType protocol_type_; + TransportPtr transport_; + ProtocolPtr protocol_; + ThriftObjectPtr upgrade_response_; bool request_complete_ : 1; bool response_started_ : 1; diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object.h b/source/extensions/filters/network/thrift_proxy/thrift_object.h new file mode 100644 index 0000000000000..321e6496674cd --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/thrift_object.h @@ -0,0 +1,247 @@ +#pragma once + +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/common/exception.h" + +#include "extensions/filters/network/thrift_proxy/thrift.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class ThriftBase; + +/** + * ThriftValue is a field or container (list, set, or map) element. + */ +class ThriftValue { +public: + virtual ~ThriftValue() {} + + /** + * @return FieldType the type of this value + */ + virtual FieldType type() const PURE; + + /** + * @return const T& pointer to the value, provided that it can be cast to the given type + * @throw EnvoyException if the type T does not match the type + */ + template const T& getValueTyped() const { + // Use the Traits template to determine what FieldType the value must have to be cast to T + // and throw if the value's type doesn't match. + FieldType expected_field_type = Traits::getFieldType(); + if (expected_field_type != type()) { + throw EnvoyException(fmt::format("expected field type {}, got {}", + static_cast(expected_field_type), + static_cast(type()))); + } + + return *static_cast(getValue()); + } + +protected: + /** + * @return void* pointing to the underlying value, to be dynamically cast in getValueTyped + */ + virtual const void* getValue() const PURE; + +private: + /** + * Traits allows getValueTyped() to enforce that the field type is being cast to the desired type. + */ + template class Traits { + public: + // Compilation failures where T does not have a member getFieldType typically mean that + // getValueTyped was called with a type T that is not used to encode Thrift values. + // The specializations below encode the valid types for Thrift primitive types. + static FieldType getFieldType() { return T::getFieldType(); } + }; +}; + +// Explicit specializations of ThriftValue::Types for primitive types. +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::Bool; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::Byte; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::I16; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::I32; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::I64; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::Double; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::String; } +}; + +typedef std::unique_ptr ThriftValuePtr; +typedef std::list ThriftValuePtrList; +typedef std::list> ThriftValuePtrPairList; + +/** + * ThriftField is a field within a ThriftStruct. + */ +class ThriftField { +public: + virtual ~ThriftField() {} + + /** + * @return FieldType this field's type + */ + virtual FieldType fieldType() const PURE; + + /** + * @return int16_t the field's identifier + */ + virtual int16_t fieldId() const PURE; + + /** + * @return const ThriftValue& containing the field's value + */ + virtual const ThriftValue& getValue() const PURE; +}; + +typedef std::unique_ptr ThriftFieldPtr; +typedef std::list ThriftFieldPtrList; + +/** + * ThriftListValue is an ordered list of ThriftValues. + */ +class ThriftListValue { +public: + virtual ~ThriftListValue() {} + + /** + * @return const ThriftValuePtrList& containing the ThriftValues that comprise the list + */ + virtual const ThriftValuePtrList& elements() const PURE; + + /** + * @return FieldType of the underlying elements + */ + virtual FieldType elementType() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::List; } +}; + +/** + * ThriftSetValue is a set of ThriftValues, maintained in their original order. + */ +class ThriftSetValue { +public: + virtual ~ThriftSetValue() {} + + /** + * @return const ThriftValuePtrList& containing the ThriftValues that comprise the set + */ + virtual const ThriftValuePtrList& elements() const PURE; + + /** + * @return FieldType of the underlying elements + */ + virtual FieldType elementType() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::Set; } +}; + +/** + * ThriftMapValue is a map of pairs of ThriftValues, maintained in their original order. + */ +class ThriftMapValue { +public: + virtual ~ThriftMapValue() {} + + /** + * @return const ThriftValuePtrPairList& containing the ThriftValue key-value paris that comprise + * the map. + */ + virtual const ThriftValuePtrPairList& elements() const PURE; + + /** + * @return FieldType of the underlying keys + */ + virtual FieldType keyType() const PURE; + + /** + * @return FieldType of the underlying values + */ + virtual FieldType valueType() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::Map; } +}; + +/** + * ThriftStructValue is a sequence of ThriftFields. + */ +class ThriftStructValue { +public: + virtual ~ThriftStructValue() {} + + /** + * @return const ThriftFieldPtrList& containing the ThriftFields that comprise the struct. + */ + virtual const ThriftFieldPtrList& fields() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::Struct; } +}; + +/** + * ThriftObject is a ThrfitStructValue that can be read from a Buffer::Instance. + */ +class ThriftObject : public ThriftStructValue { +public: + virtual ~ThriftObject() {} + + /* + * Consumes bytes from the buffer until a single complete Thrift struct has been consumed. + * @param buffer starting with a Thrift struct + * @return true when a single complete struct has been consumed; false if more data is needed to + * complete decoding + * @throw EnvoyException if the struct is invalid + */ + virtual bool onData(Buffer::Instance& buffer) PURE; +}; + +typedef std::unique_ptr ThriftObjectPtr; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object_impl.cc b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.cc new file mode 100644 index 0000000000000..fb9271331f3a2 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.cc @@ -0,0 +1,394 @@ +#include "extensions/filters/network/thrift_proxy/thrift_object_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace { + +std::unique_ptr makeValue(ThriftBase* parent, FieldType type) { + switch (type) { + case FieldType::Stop: + NOT_REACHED_GCOVR_EXCL_LINE; + + case FieldType::List: + return std::make_unique(parent); + + case FieldType::Set: + return std::make_unique(parent); + + case FieldType::Map: + return std::make_unique(parent); + + case FieldType::Struct: + return std::make_unique(parent); + + default: + return std::make_unique(parent, type); + } +} + +} // namespace + +ThriftBase::ThriftBase(ThriftBase* parent) : parent_(parent) {} + +FilterStatus ThriftBase::structBegin(absl::string_view name) { + ASSERT(delegate_ != nullptr); + return delegate_->structBegin(name); +} + +FilterStatus ThriftBase::structEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->structEnd(); +} + +FilterStatus ThriftBase::fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) { + ASSERT(delegate_ != nullptr); + return delegate_->fieldBegin(name, field_type, field_id); +} + +FilterStatus ThriftBase::fieldEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->fieldEnd(); +} + +FilterStatus ThriftBase::boolValue(bool value) { + ASSERT(delegate_ != nullptr); + return delegate_->boolValue(value); +} + +FilterStatus ThriftBase::byteValue(uint8_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->byteValue(value); +} + +FilterStatus ThriftBase::int16Value(int16_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->int16Value(value); +} + +FilterStatus ThriftBase::int32Value(int32_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->int32Value(value); +} + +FilterStatus ThriftBase::int64Value(int64_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->int64Value(value); +} + +FilterStatus ThriftBase::doubleValue(double value) { + ASSERT(delegate_ != nullptr); + return delegate_->doubleValue(value); +} + +FilterStatus ThriftBase::stringValue(absl::string_view value) { + ASSERT(delegate_ != nullptr); + return delegate_->stringValue(value); +} + +FilterStatus ThriftBase::mapBegin(FieldType key_type, FieldType value_type, uint32_t size) { + ASSERT(delegate_ != nullptr); + return delegate_->mapBegin(key_type, value_type, size); +} + +FilterStatus ThriftBase::mapEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->mapEnd(); +} + +FilterStatus ThriftBase::listBegin(FieldType elem_type, uint32_t size) { + ASSERT(delegate_ != nullptr); + return delegate_->listBegin(elem_type, size); +} + +FilterStatus ThriftBase::listEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->listEnd(); +} + +FilterStatus ThriftBase::setBegin(FieldType elem_type, uint32_t size) { + ASSERT(delegate_ != nullptr); + return delegate_->setBegin(elem_type, size); +} + +FilterStatus ThriftBase::setEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->setEnd(); +} + +void ThriftBase::delegateComplete() { + ASSERT(delegate_ != nullptr); + delegate_ = nullptr; +} + +ThriftFieldImpl::ThriftFieldImpl(ThriftStructValueImpl* parent, absl::string_view name, + FieldType field_type, int16_t field_id) + : ThriftBase(parent), name_(name), field_type_(field_type), field_id_(field_id) { + auto value = makeValue(this, field_type_); + delegate_ = value.get(); + value_ = std::move(value); +} + +FilterStatus ThriftFieldImpl::fieldEnd() { + if (delegate_) { + return delegate_->fieldEnd(); + } + + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftListValueImpl::listBegin(FieldType elem_type, uint32_t size) { + if (delegate_) { + return delegate_->listBegin(elem_type, size); + } + + elem_type_ = elem_type; + remaining_ = size; + + delegateComplete(); + + return FilterStatus::Continue; +} + +FilterStatus ThriftListValueImpl::listEnd() { + if (delegate_) { + return delegate_->listEnd(); + } + + ASSERT(remaining_ == 0); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +void ThriftListValueImpl::delegateComplete() { + delegate_ = nullptr; + + if (remaining_ == 0) { + return; + } + + auto elem = makeValue(this, elem_type_); + delegate_ = elem.get(); + elements_.push_back(std::move(elem)); + remaining_--; +} + +FilterStatus ThriftSetValueImpl::setBegin(FieldType elem_type, uint32_t size) { + if (delegate_) { + return delegate_->setBegin(elem_type, size); + } + + elem_type_ = elem_type; + remaining_ = size; + + delegateComplete(); + + return FilterStatus::Continue; +} + +FilterStatus ThriftSetValueImpl::setEnd() { + if (delegate_) { + return delegate_->setEnd(); + } + + ASSERT(remaining_ == 0); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +void ThriftSetValueImpl::delegateComplete() { + delegate_ = nullptr; + + if (remaining_ == 0) { + return; + } + + auto elem = makeValue(this, elem_type_); + delegate_ = elem.get(); + elements_.push_back(std::move(elem)); + remaining_--; +} + +FilterStatus ThriftMapValueImpl::mapBegin(FieldType key_type, FieldType elem_type, uint32_t size) { + if (delegate_) { + return delegate_->mapBegin(key_type, elem_type, size); + } + + key_type_ = key_type; + elem_type_ = elem_type; + remaining_ = size; + + delegateComplete(); + + return FilterStatus::Continue; +} + +FilterStatus ThriftMapValueImpl::mapEnd() { + if (delegate_) { + return delegate_->mapEnd(); + } + + ASSERT(remaining_ == 0); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +void ThriftMapValueImpl::delegateComplete() { + delegate_ = nullptr; + + if (remaining_ == 0) { + return; + } + + // Prepare for first element's key. + if (elements_.empty()) { + auto key = makeValue(this, key_type_); + delegate_ = key.get(); + elements_.emplace_back(std::move(key), nullptr); + return; + } + + // Prepare for any elements's value. + auto& elem = elements_.back(); + if (elem.second == nullptr) { + auto value = makeValue(this, elem_type_); + delegate_ = value.get(); + elem.second = std::move(value); + + remaining_--; + return; + } + + // Key-value pair completed, prepare for next key. + auto key = makeValue(this, key_type_); + delegate_ = key.get(); + elements_.emplace_back(std::move(key), nullptr); +} + +FilterStatus ThriftValueImpl::boolValue(bool value) { + ASSERT(value_type_ == FieldType::Bool); + bool_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::byteValue(uint8_t value) { + ASSERT(value_type_ == FieldType::Byte); + byte_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::int16Value(int16_t value) { + ASSERT(value_type_ == FieldType::I16); + int16_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::int32Value(int32_t value) { + ASSERT(value_type_ == FieldType::I32); + int32_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::int64Value(int64_t value) { + ASSERT(value_type_ == FieldType::I64); + int64_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::doubleValue(double value) { + ASSERT(value_type_ == FieldType::Double); + double_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::stringValue(absl::string_view value) { + ASSERT(value_type_ == FieldType::String); + string_value_ = std::string(value); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +const void* ThriftValueImpl::getValue() const { + switch (value_type_) { + case FieldType::Bool: + return &bool_value_; + case FieldType::Byte: + return &byte_value_; + case FieldType::I16: + return &int16_value_; + case FieldType::I32: + return &int32_value_; + case FieldType::I64: + return &int64_value_; + case FieldType::Double: + return &double_value_; + case FieldType::String: + return &string_value_; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } +} + +FilterStatus ThriftStructValueImpl::structBegin(absl::string_view name) { + if (delegate_) { + return delegate_->structBegin(name); + } + + return FilterStatus::Continue; +} + +FilterStatus ThriftStructValueImpl::structEnd() { + if (delegate_) { + return delegate_->structEnd(); + } + + if (parent_) { + parent_->delegateComplete(); + } + + return FilterStatus::Continue; +} + +FilterStatus ThriftStructValueImpl::fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) { + if (delegate_) { + return delegate_->fieldBegin(name, field_type, field_id); + } + + if (field_type != FieldType::Stop) { + auto field = std::make_unique(this, name, field_type, field_id); + delegate_ = field.get(); + fields_.emplace_back(std::move(field)); + } + + return FilterStatus::Continue; +} + +ThriftObjectImpl::ThriftObjectImpl(Transport& transport, Protocol& protocol) + : ThriftStructValueImpl(nullptr), + decoder_(std::make_unique(transport, protocol, *this)) {} + +bool ThriftObjectImpl::onData(Buffer::Instance& buffer) { + bool underflow = false; + auto result = decoder_->onData(buffer, underflow); + ASSERT(result == FilterStatus::Continue); + + if (complete_) { + decoder_.reset(); + } + return complete_; +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h new file mode 100644 index 0000000000000..b9057dfab2bc7 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h @@ -0,0 +1,262 @@ +#pragma once + +#include "extensions/filters/network/thrift_proxy/decoder.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/thrift_object.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * ThriftBase is a base class for decoding Thrift objects. It implements methods from + * DecoderEventHandler to automatically delegate to an underlying ThriftBase so that, for example, + * the fieldBegin call for a struct field nested within a list is automatically delegated down the + * object hierarchy to the correct ThriftBase subclass. + */ +class ThriftBase : public DecoderEventHandler { +public: + ThriftBase(ThriftBase* parent); + ~ThriftBase() {} + + // DecoderEventHandler + FilterStatus transportBegin(MessageMetadataSharedPtr) override { return FilterStatus::Continue; } + FilterStatus transportEnd() override { return FilterStatus::Continue; } + FilterStatus messageBegin(MessageMetadataSharedPtr) override { return FilterStatus::Continue; } + FilterStatus messageEnd() override { return FilterStatus::Continue; } + FilterStatus structBegin(absl::string_view name) override; + FilterStatus structEnd() override; + FilterStatus fieldBegin(absl::string_view name, FieldType field_type, int16_t field_id) override; + FilterStatus fieldEnd() override; + FilterStatus boolValue(bool value) override; + FilterStatus byteValue(uint8_t value) override; + FilterStatus int16Value(int16_t value) override; + FilterStatus int32Value(int32_t value) override; + FilterStatus int64Value(int64_t value) override; + FilterStatus doubleValue(double value) override; + FilterStatus stringValue(absl::string_view value) override; + FilterStatus mapBegin(FieldType key_type, FieldType value_type, uint32_t size) override; + FilterStatus mapEnd() override; + FilterStatus listBegin(FieldType elem_type, uint32_t size) override; + FilterStatus listEnd() override; + FilterStatus setBegin(FieldType elem_type, uint32_t size) override; + FilterStatus setEnd() override; + + // Invoked when the current delegate is complete. Completion implies that the delegate is fully + // specified (all list values processed, all struct fields processed, etc). + virtual void delegateComplete(); + +protected: + ThriftBase* parent_; + ThriftBase* delegate_{nullptr}; +}; + +/** + * ThriftValueBase is a base class for all struct field values, list values, set values, map keys, + * and map values. + */ +class ThriftValueBase : public ThriftValue, public ThriftBase { +public: + ThriftValueBase(ThriftBase* parent, FieldType value_type) + : ThriftBase(parent), value_type_(value_type) {} + ~ThriftValueBase() {} + + // ThriftValue + FieldType type() const override { return value_type_; } + +protected: + const FieldType value_type_; +}; + +class ThriftStructValueImpl; + +/** + * ThriftField represents a field in a thrift Struct. It always delegates DecoderEventHandler + * methods to a subclass of ThriftValueBase. + */ +class ThriftFieldImpl : public ThriftField, public ThriftBase { +public: + ThriftFieldImpl(ThriftStructValueImpl* parent, absl::string_view name, FieldType field_type, + int16_t field_id); + + // DecoderEventHandler + FilterStatus fieldEnd() override; + + // ThriftField + FieldType fieldType() const override { return field_type_; } + int16_t fieldId() const override { return field_id_; } + const ThriftValue& getValue() const override { return *value_; } + +private: + std::string name_; + FieldType field_type_; + int16_t field_id_; + ThriftValuePtr value_; +}; + +/** + * ThriftStructValueImpl implements ThriftStruct. + */ +class ThriftStructValueImpl : public ThriftStructValue, public ThriftValueBase { +public: + ThriftStructValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::Struct) {} + + // DecoderEventHandler + FilterStatus structBegin(absl::string_view name) override; + FilterStatus structEnd() override; + FilterStatus fieldBegin(absl::string_view name, FieldType field_type, int16_t field_id) override; + + // ThriftStructValue + const ThriftFieldPtrList& fields() const override { return fields_; } + +private: + // ThriftValue + const void* getValue() const override { return this; }; + + ThriftFieldPtrList fields_; +}; + +/** + * ThriftListValueImpl represents Thrift lists. + */ +class ThriftListValueImpl : public ThriftListValue, public ThriftValueBase { +public: + ThriftListValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::List) {} + + // DecoderEventHandler + FilterStatus listBegin(FieldType elem_type, uint32_t size) override; + FilterStatus listEnd() override; + + // ThriftListValue + const ThriftValuePtrList& elements() const override { return elements_; } + FieldType elementType() const override { return elem_type_; } + + void delegateComplete() override; + +protected: + // ThriftValue + const void* getValue() const override { return this; }; + + FieldType elem_type_{FieldType::Stop}; + uint32_t remaining_{0}; + ThriftValuePtrList elements_; +}; + +/** + * ThriftSetValueImpl represents Thrift sets. + */ +class ThriftSetValueImpl : public ThriftSetValue, public ThriftValueBase { +public: + ThriftSetValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::Set) {} + + // DecoderEventHandler + FilterStatus setBegin(FieldType elem_type, uint32_t size) override; + FilterStatus setEnd() override; + + // ThriftSetValue + const ThriftValuePtrList& elements() const override { return elements_; } + FieldType elementType() const override { return elem_type_; } + + void delegateComplete() override; + +protected: + // ThriftValue + const void* getValue() const override { return this; }; + + FieldType elem_type_{FieldType::Stop}; + uint32_t remaining_{0}; + ThriftValuePtrList elements_; // maintain original order +}; + +/** + * ThriftMapValueImpl represents Thrift maps. + */ +class ThriftMapValueImpl : public ThriftMapValue, public ThriftValueBase { +public: + ThriftMapValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::Map) {} + + // DecoderEventHandler + FilterStatus mapBegin(FieldType key_type, FieldType elem_type, uint32_t size) override; + FilterStatus mapEnd() override; + + // ThriftMapValue + const ThriftValuePtrPairList& elements() const override { return elements_; } + FieldType keyType() const override { return key_type_; } + FieldType valueType() const override { return elem_type_; } + + void delegateComplete() override; + +protected: + // ThriftValue + const void* getValue() const override { return this; }; + + FieldType key_type_{FieldType::Stop}; + FieldType elem_type_{FieldType::Stop}; + uint32_t remaining_{0}; + ThriftValuePtrPairList elements_; // maintain original order +}; + +/** + * ThriftValueImpl represents primitive Thrift types, including strings. + */ +class ThriftValueImpl : public ThriftValueBase { +public: + ThriftValueImpl(ThriftBase* parent, FieldType value_type) : ThriftValueBase(parent, value_type) {} + + // DecoderEventHandler + FilterStatus boolValue(bool value) override; + FilterStatus byteValue(uint8_t value) override; + FilterStatus int16Value(int16_t value) override; + FilterStatus int32Value(int32_t value) override; + FilterStatus int64Value(int64_t value) override; + FilterStatus doubleValue(double value) override; + FilterStatus stringValue(absl::string_view value) override; + +protected: + // ThriftValue + const void* getValue() const override; + +private: + union { + bool bool_value_; + uint8_t byte_value_; + int16_t int16_value_; + int32_t int32_value_; + int64_t int64_value_; + double double_value_; + }; + std::string string_value_; +}; + +/** + * ThriftObjectImpl is a generic representation of a Thrift struct. + */ +class ThriftObjectImpl : public ThriftObject, + public ThriftStructValueImpl, + public DecoderCallbacks { +public: + ThriftObjectImpl(Transport& transport, Protocol& protocol); + + // DecoderCallbacks + DecoderEventHandler& newDecoderEventHandler() override { return *this; } + FilterStatus transportEnd() override { + complete_ = true; + return FilterStatus::Continue; + } + + // ThriftObject + bool onData(Buffer::Instance& buffer) override; + + // ThriftStruct + const ThriftFieldPtrList& fields() const override { return ThriftStructValueImpl::fields(); } + +private: + DecoderPtr decoder_; + bool complete_{false}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index 1cc27974ea7b2..7a9d3c4b23ff3 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -245,6 +245,19 @@ envoy_extension_cc_test( ], ) +envoy_extension_cc_test( + name = "thrift_object_impl_test", + srcs = ["thrift_object_impl_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + ":mocks", + ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:thrift_object_lib", + "//test/test_common:printers_lib", + "//test/test_common:registry_lib", + ], +) + envoy_extension_cc_test( name = "integration_test", srcs = ["integration_test.cc"], diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc index 43f3352149723..338ce458123df 100644 --- a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -2,9 +2,12 @@ #include "common/buffer/buffer_impl.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" #include "extensions/filters/network/thrift_proxy/buffer_helper.h" #include "extensions/filters/network/thrift_proxy/config.h" #include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/header_transport_impl.h" #include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" @@ -17,8 +20,10 @@ #include "gtest/gtest.h" using testing::_; +using testing::AnyNumber; using testing::Invoke; using testing::NiceMock; +using testing::Ref; using testing::Return; using testing::ReturnRef; @@ -39,10 +44,23 @@ class TestConfigImpl : public ConfigImpl { void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override { callbacks.addDecoderFilter(decoder_filter_); } + TransportPtr createTransport() override { + if (transport_) { + return TransportPtr{transport_}; + } + return ConfigImpl::createTransport(); + } + ProtocolPtr createProtocol() override { + if (protocol_) { + return ProtocolPtr{protocol_}; + } + return ConfigImpl::createProtocol(); + } -private: ThriftFilters::DecoderFilterSharedPtr decoder_filter_; ThriftFilterStats& stats_; + MockTransport* transport_{}; + MockProtocol* protocol_{}; }; class ThriftConnectionManagerTest : public testing::Test { @@ -72,7 +90,14 @@ class ThriftConnectionManagerTest : public testing::Test { proto_config_.set_stat_prefix("test"); decoder_filter_.reset(new NiceMock()); + config_.reset(new TestConfigImpl(proto_config_, context_, decoder_filter_, stats_)); + if (custom_transport_) { + config_->transport_ = custom_transport_; + } + if (custom_protocol_) { + config_->protocol_ = custom_protocol_; + } filter_.reset(new ConnectionManager(*config_)); filter_->initializeReadFilterCallbacks(filter_callbacks_); @@ -267,6 +292,9 @@ class ThriftConnectionManagerTest : public testing::Test { Buffer::OwnedImpl write_buffer_; std::unique_ptr filter_; NiceMock filter_callbacks_; + + MockTransport* custom_transport_{}; + MockProtocol* custom_protocol_{}; }; TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftCall) { @@ -602,7 +630,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndResponse) { writeComplexFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -634,7 +664,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndExceptionResponse) { writeFramedBinaryTApplicationException(write_buffer_, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -667,7 +699,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndErrorResponse) { writeFramedBinaryIDLException(write_buffer_, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -700,7 +734,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndInvalidResponse) { // Call is not valid in a response writeFramedBinaryMessage(write_buffer_, MessageType::Call, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -739,7 +775,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndResponseProtocolError) { 0x08, 0xff, 0xff // illegal field id }); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_, write(_, false)); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); @@ -781,7 +819,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndTransportApplicationException) { 0x01, 0x02, 0x00, 0x00, // transforms: 1, 2; padding }); - callbacks->startUpstreamResponse(TransportType::Header, ProtocolType::Binary); + HeaderTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -817,15 +857,18 @@ TEST_F(ThriftConnectionManagerTest, PipelinedRequestAndResponse) { EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(2); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x01); - callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + callbacks.front()->startUpstreamResponse(transport, proto); EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); callbacks.pop_front(); EXPECT_EQ(1U, store_.counter("test.response").value()); EXPECT_EQ(1U, store_.counter("test.response_reply").value()); writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x02); - callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + callbacks.front()->startUpstreamResponse(transport, proto); EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); callbacks.pop_front(); EXPECT_EQ(2U, store_.counter("test.response").value()); @@ -857,6 +900,68 @@ TEST_F(ThriftConnectionManagerTest, ResetDownstreamConnection) { EXPECT_EQ(0U, store_.gauge("test.request_active").value()); } +TEST_F(ThriftConnectionManagerTest, DownstreamProtocolUpgrade) { + custom_transport_ = new NiceMock(); + custom_protocol_ = new NiceMock(); + initializeFilter(); + + EXPECT_CALL(*custom_transport_, decodeFrameStart(_, _)).WillOnce(Return(true)); + EXPECT_CALL(*custom_protocol_, readMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMessageType(MessageType::Call); + metadata.setProtocolUpgradeMessage(true); + return true; + })); + EXPECT_CALL(*custom_protocol_, supportsUpgrade()).Times(AnyNumber()).WillRepeatedly(Return(true)); + + MockDecoderEventHandler* upgrade_decoder = new NiceMock(); + EXPECT_CALL(*custom_protocol_, upgradeRequestDecoder()) + .WillOnce(Invoke([&]() -> DecoderEventHandlerSharedPtr { + return DecoderEventHandlerSharedPtr{upgrade_decoder}; + })); + EXPECT_CALL(*upgrade_decoder, messageBegin(_)).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_protocol_, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, structBegin(_)).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_protocol_, readFieldBegin(_, _, _, _)) + .WillOnce(Invoke( + [&](Buffer::Instance&, std::string&, FieldType& field_type, int16_t& field_id) -> bool { + field_type = FieldType::Stop; + field_id = 0; + return true; + })); + EXPECT_CALL(*custom_protocol_, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, structEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_protocol_, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, messageEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_transport_, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, transportEnd()).WillOnce(Return(FilterStatus::Continue)); + + MockDirectResponse* direct_response = new NiceMock(); + + EXPECT_CALL(*custom_protocol_, upgradeResponse(Ref(*upgrade_decoder))) + .WillOnce(Invoke([&](const DecoderEventHandler&) -> DirectResponsePtr { + return DirectResponsePtr{direct_response}; + })); + + EXPECT_CALL(*direct_response, encode(_, Ref(*custom_protocol_), _)) + .WillOnce(Invoke([&](MessageMetadata&, Protocol&, Buffer::Instance& buffer) -> void { + buffer.add("response"); + })); + EXPECT_CALL(*custom_transport_, encodeFrame(_, _, _)) + .WillOnce(Invoke( + [&](Buffer::Instance& buffer, const MessageMetadata&, Buffer::Instance& message) -> void { + EXPECT_EQ("response", message.toString()); + buffer.add("transport-encoded response"); + })); + EXPECT_CALL(filter_callbacks_.connection_, write(_, false)) + .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { + EXPECT_EQ("transport-encoded response", buffer.toString()); + })); + + Buffer::OwnedImpl buffer; + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index 6fd60d2256814..a8369e757b220 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -710,17 +710,17 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { } TEST(DecoderTest, OnData) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; @@ -732,7 +732,7 @@ TEST(DecoderTest, OnData) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); @@ -750,18 +750,18 @@ TEST(DecoderTest, OnData) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, structEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::Continue)); bool underflow = false; @@ -770,24 +770,24 @@ TEST(DecoderTest, OnData) { } TEST(DecoderTest, OnDataWithProtocolHint) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); metadata.setProtocol(ProtocolType::Binary); return true; })); - EXPECT_CALL(*proto, type()).WillOnce(Return(ProtocolType::Auto)); - EXPECT_CALL(*proto, setType(ProtocolType::Binary)); + EXPECT_CALL(proto, type()).WillOnce(Return(ProtocolType::Auto)); + EXPECT_CALL(proto, setType(ProtocolType::Binary)); EXPECT_CALL(handler, transportBegin(_)) .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { EXPECT_TRUE(metadata->hasFrameSize()); @@ -799,7 +799,7 @@ TEST(DecoderTest, OnDataWithProtocolHint) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); @@ -817,18 +817,18 @@ TEST(DecoderTest, OnDataWithProtocolHint) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, structEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::Continue)); bool underflow = false; @@ -837,23 +837,23 @@ TEST(DecoderTest, OnDataWithProtocolHint) { } TEST(DecoderTest, OnDataWithInconsistentProtocolHint) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); metadata.setProtocol(ProtocolType::Binary); return true; })); - EXPECT_CALL(*proto, type()).WillRepeatedly(Return(ProtocolType::Compact)); + EXPECT_CALL(proto, type()).WillRepeatedly(Return(ProtocolType::Compact)); bool underflow = false; EXPECT_THROW_WITH_MESSAGE(decoder.onData(buffer, underflow), EnvoyException, @@ -861,17 +861,17 @@ TEST(DecoderTest, OnDataWithInconsistentProtocolHint) { } TEST(DecoderTest, OnDataThrowsTransportAppException) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setAppException(AppExceptionType::InvalidTransform, "unknown xform"); return true; @@ -882,85 +882,85 @@ TEST(DecoderTest, OnDataThrowsTransportAppException) { } TEST(DecoderTest, OnDataResumes) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; NiceMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; buffer.add("x"); - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; })); - EXPECT_CALL(*proto, readMessageBegin(_, _)) + EXPECT_CALL(proto, readMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); metadata.setSequenceId(100); return true; })); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(false)); + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(false)); bool underflow = false; EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); // buffer.length() == 1 } TEST(DecoderTest, OnDataResumesTransportFrameStart) { - StrictMock* transport = new StrictMock(); - StrictMock* proto = new StrictMock(); + StrictMock transport; + StrictMock proto; NiceMock callbacks; NiceMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); - EXPECT_CALL(*transport, name()).Times(AnyNumber()); - EXPECT_CALL(*proto, name()).Times(AnyNumber()); + EXPECT_CALL(transport, name()).Times(AnyNumber()); + EXPECT_CALL(proto, name()).Times(AnyNumber()); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; bool underflow = false; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)).WillOnce(Return(false)); + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)).WillOnce(Return(false)); EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; })); - EXPECT_CALL(*proto, readMessageBegin(_, _)) + EXPECT_CALL(proto, readMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); metadata.setSequenceId(100); return true; })); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); underflow = false; EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); @@ -968,66 +968,65 @@ TEST(DecoderTest, OnDataResumesTransportFrameStart) { } TEST(DecoderTest, OnDataResumesTransportFrameEnd) { - StrictMock* transport = new StrictMock(); - StrictMock* proto = new StrictMock(); + StrictMock transport; + StrictMock proto; NiceMock callbacks; NiceMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); - EXPECT_CALL(*transport, name()).Times(AnyNumber()); - EXPECT_CALL(*proto, name()).Times(AnyNumber()); + EXPECT_CALL(transport, name()).Times(AnyNumber()); + EXPECT_CALL(proto, name()).Times(AnyNumber()); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; })); - EXPECT_CALL(*proto, readMessageBegin(_, _)) + EXPECT_CALL(proto, readMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); metadata.setSequenceId(100); return true; })); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(false)); + EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(false)); bool underflow = false; EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); // buffer.length() == 0 } TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { + StrictMock transport; + EXPECT_CALL(transport, name()).WillRepeatedly(ReturnRef(transport.name_)); - StrictMock* transport = new StrictMock(); - EXPECT_CALL(*transport, name()).WillRepeatedly(ReturnRef(transport->name_)); - - StrictMock* proto = new StrictMock(); - EXPECT_CALL(*proto, name()).WillRepeatedly(ReturnRef(proto->name_)); + StrictMock proto; + EXPECT_CALL(proto, name()).WillRepeatedly(ReturnRef(proto.name_)); NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; bool underflow = true; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; @@ -1042,7 +1041,7 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); @@ -1062,42 +1061,42 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())) .WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::I32), SetArgReferee<3>(1), Return(true))); EXPECT_CALL(handler, fieldBegin(absl::string_view(), FieldType::I32, 1)) .WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readInt32(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readInt32(_, _)).WillOnce(Return(true)); EXPECT_CALL(handler, int32Value(_)).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, fieldEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, structEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); diff --git a/test/extensions/filters/network/thrift_proxy/mocks.cc b/test/extensions/filters/network/thrift_proxy/mocks.cc index ad14d3c1149e1..ac0359dce7c17 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.cc +++ b/test/extensions/filters/network/thrift_proxy/mocks.cc @@ -27,6 +27,7 @@ MockProtocol::MockProtocol() { ON_CALL(*this, setType(_)).WillByDefault(Invoke([&](ProtocolType type) -> void { type_ = type; })); + ON_CALL(*this, supportsUpgrade()).WillByDefault(Return(false)); } MockProtocol::~MockProtocol() {} @@ -39,6 +40,9 @@ MockDecoderEventHandler::~MockDecoderEventHandler() {} MockDirectResponse::MockDirectResponse() {} MockDirectResponse::~MockDirectResponse() {} +MockThriftObject::MockThriftObject() {} +MockThriftObject::~MockThriftObject() {} + namespace ThriftFilters { MockDecoderFilter::MockDecoderFilter() { diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index b93b7717c8780..2067434b2db88 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -1,6 +1,7 @@ #pragma once #include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/conn_state.h" #include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/metadata.h" #include "extensions/filters/network/thrift_proxy/protocol.h" @@ -101,6 +102,12 @@ class MockProtocol : public Protocol { MOCK_METHOD2(writeDouble, void(Buffer::Instance& buffer, double value)); MOCK_METHOD2(writeString, void(Buffer::Instance& buffer, const std::string& value)); MOCK_METHOD2(writeBinary, void(Buffer::Instance& buffer, const std::string& value)); + MOCK_METHOD0(supportsUpgrade, bool()); + MOCK_METHOD0(upgradeRequestDecoder, DecoderEventHandlerSharedPtr()); + MOCK_METHOD1(upgradeResponse, DirectResponsePtr(const DecoderEventHandler&)); + MOCK_METHOD3(attemptUpgrade, + ThriftObjectPtr(Transport&, ThriftConnectionState&, Buffer::Instance&)); + MOCK_METHOD2(completeUpgrade, void(ThriftConnectionState&, ThriftObject&)); std::string name_{"mock"}; ProtocolType type_{ProtocolType::Auto}; @@ -154,6 +161,15 @@ class MockDirectResponse : public DirectResponse { MOCK_CONST_METHOD3(encode, void(MessageMetadata&, Protocol&, Buffer::Instance&)); }; +class MockThriftObject : public ThriftObject { +public: + MockThriftObject(); + ~MockThriftObject(); + + MOCK_CONST_METHOD0(fields, ThriftFieldPtrList&()); + MOCK_METHOD1(onData, bool(Buffer::Instance&)); +}; + namespace ThriftFilters { class MockDecoderFilter : public DecoderFilter { @@ -204,7 +220,7 @@ class MockDecoderFilterCallbacks : public DecoderFilterCallbacks { MOCK_CONST_METHOD0(downstreamTransportType, TransportType()); MOCK_CONST_METHOD0(downstreamProtocolType, ProtocolType()); MOCK_METHOD1(sendLocalReply, void(const DirectResponse&)); - MOCK_METHOD2(startUpstreamResponse, void(TransportType, ProtocolType)); + MOCK_METHOD2(startUpstreamResponse, void(Transport&, Protocol&)); MOCK_METHOD1(upstreamData, bool(Buffer::Instance&)); MOCK_METHOD0(resetDownstreamConnection, void()); diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc index 3b07f4d9186dd..31ce50e8327c7 100644 --- a/test/extensions/filters/network/thrift_proxy/router_test.cc +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -21,6 +21,7 @@ using testing::_; using testing::ContainsRegex; +using testing::InSequence; using testing::Invoke; using testing::NiceMock; using testing::Ref; @@ -71,8 +72,22 @@ class TestNamedProtocolConfigFactory : public NamedProtocolConfigFactory { class ThriftRouterTestBase { public: ThriftRouterTestBase() - : transport_factory_([&]() -> MockTransport* { return transport_; }), - protocol_factory_([&]() -> MockProtocol* { return protocol_; }), + : transport_factory_([&]() -> MockTransport* { + ASSERT(transport_ == nullptr); + transport_ = new NiceMock(); + if (mock_transport_cb_) { + mock_transport_cb_(transport_); + } + return transport_; + }), + protocol_factory_([&]() -> MockProtocol* { + ASSERT(protocol_ == nullptr); + protocol_ = new NiceMock(); + if (mock_protocol_cb_) { + mock_protocol_cb_(protocol_); + } + return protocol_; + }), transport_register_(transport_factory_), protocol_register_(protocol_factory_) {} void initializeRouter() { @@ -124,9 +139,6 @@ class ThriftRouterTestBase { upstream_callbacks_ = &cb; })); - protocol_ = new NiceMock(); - - ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary)); EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { EXPECT_EQ(metadata_->methodName(), metadata.methodName()); @@ -165,15 +177,16 @@ class ThriftRouterTestBase { EXPECT_CALL(callbacks_, downstreamTransportType()).WillOnce(Return(TransportType::Framed)); EXPECT_CALL(callbacks_, downstreamProtocolType()).WillOnce(Return(ProtocolType::Binary)); - protocol_ = new NiceMock(); - ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary)); - EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) - .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { - EXPECT_EQ(metadata_->methodName(), metadata.methodName()); - EXPECT_EQ(metadata_->messageType(), metadata.messageType()); - EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); - })); + mock_protocol_cb_ = [&](MockProtocol* protocol) -> void { + ON_CALL(*protocol, type()).WillByDefault(Return(ProtocolType::Binary)); + EXPECT_CALL(*protocol, writeMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ(metadata_->methodName(), metadata.methodName()); + EXPECT_EQ(metadata_->messageType(), metadata.messageType()); + EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); + })); + }; EXPECT_CALL(callbacks_, continueDecoding()).Times(0); EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, newConnection(_)) .WillOnce( @@ -240,8 +253,6 @@ class ThriftRouterTestBase { } void completeRequest() { - transport_ = new NiceMock(); - EXPECT_CALL(*protocol_, writeMessageEnd(_)); EXPECT_CALL(*transport_, encodeFrame(_, _, _)); EXPECT_CALL(upstream_connection_, write(_, false)); @@ -257,7 +268,7 @@ class ThriftRouterTestBase { void returnResponse() { Buffer::OwnedImpl buffer; - EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, startUpstreamResponse(_, _)); EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); upstream_callbacks_->onUpstreamData(buffer, false); @@ -277,6 +288,9 @@ class ThriftRouterTestBase { Registry::InjectFactory transport_register_; Registry::InjectFactory protocol_register_; + std::function mock_transport_cb_{}; + std::function mock_protocol_cb_{}; + NiceMock context_; NiceMock callbacks_; NiceMock* transport_{}; @@ -472,7 +486,7 @@ TEST_F(ThriftRouterTest, TruncatedResponse) { Buffer::OwnedImpl buffer; - EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, startUpstreamResponse(_, _)); EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, released(Ref(upstream_connection_))); EXPECT_CALL(callbacks_, resetDownstreamConnection()); @@ -531,7 +545,7 @@ TEST_F(ThriftRouterTest, UpstreamDataTriggersReset) { Buffer::OwnedImpl buffer; - EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, startUpstreamResponse(_, _)); EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))) .WillOnce(Invoke([&](Buffer::Instance&) -> bool { router_->resetUpstreamConnection(); @@ -590,6 +604,100 @@ TEST_F(ThriftRouterTest, UnexpectedRouterDestroy) { destroyRouter(); } +TEST_F(ThriftRouterTest, ProtocolUpgrade) { + initializeRouter(); + startRequest(MessageType::Call); + + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void { upstream_callbacks_ = &cb; })); + + Tcp::ConnectionPool::ConnectionStatePtr conn_state; + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, connectionState()) + .WillRepeatedly( + Invoke([&]() -> Tcp::ConnectionPool::ConnectionState* { return conn_state.get(); })); + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, setConnectionState_(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::ConnectionStatePtr& cs) -> void { conn_state.swap(cs); })); + + EXPECT_CALL(*protocol_, supportsUpgrade()).WillOnce(Return(true)); + + MockThriftObject* upgrade_response = new NiceMock(); + + EXPECT_CALL(*protocol_, attemptUpgrade(_, _, _)) + .WillOnce(Invoke( + [&](Transport&, ThriftConnectionState&, Buffer::Instance& buffer) -> ThriftObjectPtr { + buffer.add("upgrade request"); + return ThriftObjectPtr{upgrade_response}; + })); + EXPECT_CALL(upstream_connection_, write(_, false)) + .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { + EXPECT_EQ("upgrade request", buffer.toString()); + })); + + context_.cluster_manager_.tcp_conn_pool_.poolReady(upstream_connection_); + EXPECT_NE(nullptr, upstream_callbacks_); + + Buffer::OwnedImpl buffer; + EXPECT_CALL(*upgrade_response, onData(Ref(buffer))).WillOnce(Return(false)); + upstream_callbacks_->onUpstreamData(buffer, false); + + EXPECT_CALL(*upgrade_response, onData(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(*protocol_, completeUpgrade(_, Ref(*upgrade_response))); + EXPECT_CALL(callbacks_, continueDecoding()); + EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ(metadata_->methodName(), metadata.methodName()); + EXPECT_EQ(metadata_->messageType(), metadata.messageType()); + EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); + })); + upstream_callbacks_->onUpstreamData(buffer, false); + + // Then the actual request... + sendTrivialStruct(FieldType::String); + completeRequest(); + returnResponse(); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, ProtocolUpgradeSkippedOnExistingConnection) { + initializeRouter(); + startRequest(MessageType::Call); + + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void { upstream_callbacks_ = &cb; })); + + Tcp::ConnectionPool::ConnectionStatePtr conn_state = std::make_unique(); + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, connectionState()) + .WillRepeatedly( + Invoke([&]() -> Tcp::ConnectionPool::ConnectionState* { return conn_state.get(); })); + + EXPECT_CALL(*protocol_, supportsUpgrade()).WillOnce(Return(true)); + + // Protocol determines that connection state shows upgrade already occurred + EXPECT_CALL(*protocol_, attemptUpgrade(_, _, _)) + .WillOnce(Invoke([&](Transport&, ThriftConnectionState&, + Buffer::Instance&) -> ThriftObjectPtr { return nullptr; })); + + EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ(metadata_->methodName(), metadata.methodName()); + EXPECT_EQ(metadata_->messageType(), metadata.messageType()); + EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); + })); + EXPECT_CALL(callbacks_, continueDecoding()); + + context_.cluster_manager_.tcp_conn_pool_.poolReady(upstream_connection_); + EXPECT_NE(nullptr, upstream_callbacks_); + + // Then the actual request... + sendTrivialStruct(FieldType::String); + completeRequest(); + returnResponse(); + destroyRouter(); +} + TEST_P(ThriftRouterFieldTypeTest, OneWay) { FieldType field_type = GetParam(); diff --git a/test/extensions/filters/network/thrift_proxy/thrift_object_impl_test.cc b/test/extensions/filters/network/thrift_proxy/thrift_object_impl_test.cc new file mode 100644 index 0000000000000..3e8d12403dcd8 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/thrift_object_impl_test.cc @@ -0,0 +1,494 @@ +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/thrift_object_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/test_common/printers.h" +#include "test/test_common/utility.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::AnyNumber; +using testing::Expectation; +using testing::ExpectationSet; +using testing::InSequence; +using testing::NiceMock; +using testing::Ref; +using testing::Return; +using testing::ReturnRef; +using testing::Test; +using testing::TestWithParam; +using testing::Values; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class ThriftObjectImplTestBase { +public: + virtual ~ThriftObjectImplTestBase() {} + + Expectation expectValue(FieldType field_type) { + switch (field_type) { + case FieldType::Bool: + return EXPECT_CALL(protocol_, readBool(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, bool& value) -> bool { + value = true; + return true; + })); + case FieldType::Byte: + return EXPECT_CALL(protocol_, readByte(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, uint8_t& value) -> bool { + value = 1; + return true; + })); + case FieldType::Double: + return EXPECT_CALL(protocol_, readDouble(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, double& value) -> bool { + value = 2.0; + return true; + })); + case FieldType::I16: + return EXPECT_CALL(protocol_, readInt16(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, int16_t& value) -> bool { + value = 3; + return true; + })); + case FieldType::I32: + return EXPECT_CALL(protocol_, readInt32(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, int32_t& value) -> bool { + value = 4; + return true; + })); + case FieldType::I64: + return EXPECT_CALL(protocol_, readInt64(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, int64_t& value) -> bool { + value = 5; + return true; + })); + case FieldType::String: + return EXPECT_CALL(protocol_, readString(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, std::string& value) -> bool { + value = "six"; + return true; + })); + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } + + Expectation expectFieldBegin(FieldType field_type, int16_t field_id) { + return EXPECT_CALL(protocol_, readFieldBegin(Ref(buffer_), _, _, _)) + .WillOnce( + Invoke([=](Buffer::Instance&, std::string&, FieldType& type, int16_t& id) -> bool { + type = field_type; + id = field_id; + return true; + })); + } + + Expectation expectFieldEnd() { + return EXPECT_CALL(protocol_, readFieldEnd(Ref(buffer_))).WillOnce(Return(true)); + } + + ExpectationSet expectField(FieldType field_type, int16_t field_id) { + ExpectationSet s; + s += expectFieldBegin(field_type, field_id); + s += expectValue(field_type); + s += expectFieldEnd(); + return s; + } + + Expectation expectStopField() { return expectFieldBegin(FieldType::Stop, 0); } + + void checkValue(FieldType field_type, const ThriftValue& value) { + EXPECT_EQ(field_type, value.type()); + + switch (field_type) { + case FieldType::Bool: + EXPECT_EQ(true, value.getValueTyped()); + break; + case FieldType::Byte: + EXPECT_EQ(1, value.getValueTyped()); + break; + case FieldType::Double: + EXPECT_EQ(2.0, value.getValueTyped()); + break; + case FieldType::I16: + EXPECT_EQ(3, value.getValueTyped()); + break; + case FieldType::I32: + EXPECT_EQ(4, value.getValueTyped()); + break; + case FieldType::I64: + EXPECT_EQ(5, value.getValueTyped()); + break; + case FieldType::String: + EXPECT_EQ("six", value.getValueTyped()); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } + + void checkFieldValue(const ThriftField& field) { + const ThriftValue& value = field.getValue(); + checkValue(field.fieldType(), value); + } + + NiceMock transport_; + NiceMock protocol_; + Buffer::OwnedImpl buffer_; +}; + +class ThriftObjectImplTest : public ThriftObjectImplTestBase, public Test {}; + +// Test parsing an empty struct (just a stop field). +TEST_F(ThriftObjectImplTest, ParseEmptyStruct) { + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_TRUE(thrift_obj.fields().empty()); +} + +class ThriftObjectImplValueTest : public ThriftObjectImplTestBase, + public TestWithParam {}; + +INSTANTIATE_TEST_CASE_P(PrimitiveFieldTypes, ThriftObjectImplValueTest, + Values(FieldType::Bool, FieldType::Byte, FieldType::Double, FieldType::I16, + FieldType::I32, FieldType::I64, FieldType::String), + fieldTypeParamToString); + +// Test parsing a struct with a single field with a simple value. +TEST_P(ThriftObjectImplValueTest, ParseSingleValueStruct) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(field_type, 1); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + EXPECT_EQ(field_type, thrift_obj.fields().front()->fieldType()); + EXPECT_EQ(1, thrift_obj.fields().front()->fieldId()); + checkFieldValue(*thrift_obj.fields().front()); +} + +// Test parsing nested structs (struct -> struct -> simple field). +TEST_P(ThriftObjectImplValueTest, ParseNestedSingleValueStruct) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Struct, 1); + + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(field_type, 2); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(FieldType::Struct, field.fieldType()); + + const ThriftStructValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(1, nested.fields().size()); + EXPECT_EQ(field_type, nested.fields().front()->fieldType()); + EXPECT_EQ(2, nested.fields().front()->fieldId()); + checkFieldValue(*nested.fields().front()); +} + +// Test parsing a struct with a single list field (struct -> list). +TEST_P(ThriftObjectImplValueTest, ParseNestedListValue) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::List, 1); + + EXPECT_CALL(protocol_, readListBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& type, uint32_t& size) -> bool { + type = field_type; + size = 2; + return true; + })); + expectValue(field_type); + expectValue(field_type); + EXPECT_CALL(protocol_, readListEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(1, field.fieldId()); + EXPECT_EQ(FieldType::List, field.fieldType()); + + const ThriftListValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(field_type, nested.elementType()); + EXPECT_EQ(2, nested.elements().size()); + for (auto& value : nested.elements()) { + checkValue(field_type, *value); + } +} + +// Test parsing a struct with a single set field (struct -> set). +TEST_P(ThriftObjectImplValueTest, ParseNestedSetValue) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Set, 1); + + EXPECT_CALL(protocol_, readSetBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& type, uint32_t& size) -> bool { + type = field_type; + size = 2; + return true; + })); + expectValue(field_type); + expectValue(field_type); + EXPECT_CALL(protocol_, readSetEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(1, field.fieldId()); + EXPECT_EQ(FieldType::Set, field.fieldType()); + + const ThriftSetValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(field_type, nested.elementType()); + EXPECT_EQ(2, nested.elements().size()); + for (auto& value : nested.elements()) { + checkValue(field_type, *value); + } +} + +// Test parsing a struct with a single map field (struct -> map). +TEST_P(ThriftObjectImplValueTest, ParseNestedMapValue) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Map, 1); + + EXPECT_CALL(protocol_, readMapBegin(Ref(buffer_), _, _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& key_type, FieldType& value_type, + uint32_t& size) -> bool { + key_type = field_type; + value_type = FieldType::String; + size = 2; + return true; + })); + expectValue(field_type); + expectValue(FieldType::String); + expectValue(field_type); + expectValue(FieldType::String); + EXPECT_CALL(protocol_, readMapEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(1, field.fieldId()); + EXPECT_EQ(FieldType::Map, field.fieldType()); + + const ThriftMapValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(field_type, nested.keyType()); + EXPECT_EQ(FieldType::String, nested.valueType()); + EXPECT_EQ(2, nested.elements().size()); + for (auto& value : nested.elements()) { + checkValue(field_type, *value.first); + checkValue(FieldType::String, *value.second); + } +} + +// Test a struct with a map -> list -> set -> map -> list -> set -> struct. +TEST_F(ThriftObjectImplTest, DeeplyNestedStruct) { + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Map, 1); + + EXPECT_CALL(protocol_, readMapBegin(Ref(buffer_), _, _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& key_type, FieldType& value_type, + uint32_t& size) -> bool { + key_type = FieldType::I32; + value_type = FieldType::List; + size = 1; + return true; + })); + expectValue(FieldType::I32); + EXPECT_CALL(protocol_, readListBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Set; + size = 1; + return true; + })); + EXPECT_CALL(protocol_, readSetBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Map; + size = 1; + return true; + })); + + EXPECT_CALL(protocol_, readMapBegin(Ref(buffer_), _, _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& key_type, FieldType& value_type, + uint32_t& size) -> bool { + key_type = FieldType::I32; + value_type = FieldType::List; + size = 1; + return true; + })); + expectValue(FieldType::I32); + EXPECT_CALL(protocol_, readListBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Set; + size = 1; + return true; + })); + EXPECT_CALL(protocol_, readSetBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Struct; + size = 1; + return true; + })); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(FieldType::I64, 100); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readSetEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readListEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMapEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readSetEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readListEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMapEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + + EXPECT_EQ(FieldType::Map, thrift_obj.fields().front()->fieldType()); + const ThriftMapValue& map_value = + thrift_obj.fields().front()->getValue().getValueTyped(); + EXPECT_EQ(1, map_value.elements().size()); + + const ThriftListValue& list_value = + map_value.elements().front().second->getValueTyped(); + EXPECT_EQ(1, list_value.elements().size()); + + const ThriftSetValue& set_value = list_value.elements().front()->getValueTyped(); + EXPECT_EQ(1, set_value.elements().size()); + + const ThriftMapValue& map_value2 = set_value.elements().front()->getValueTyped(); + EXPECT_EQ(1, map_value2.elements().size()); + + const ThriftListValue& list_value2 = + map_value2.elements().front().second->getValueTyped(); + EXPECT_EQ(1, list_value2.elements().size()); + + const ThriftSetValue& set_value2 = + list_value2.elements().front()->getValueTyped(); + EXPECT_EQ(1, set_value2.elements().size()); + + const ThriftStructValue& struct_value = + set_value2.elements().front()->getValueTyped(); + EXPECT_EQ(1, struct_value.fields().size()); + + EXPECT_EQ(5, struct_value.fields().front()->getValue().getValueTyped()); +} + +// Tests when caller requests wrong value type. +TEST_F(ThriftObjectImplTest, WrongValueType) { + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(FieldType::String, 1); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + + const ThriftValue& value = thrift_obj.fields().front()->getValue(); + EXPECT_THROW_WITH_MESSAGE(value.getValueTyped(), EnvoyException, + fmt::format("expected field type {}, got {}", + static_cast(FieldType::I32), + static_cast(FieldType::String))); +} + +} // Namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy From ee710d0a92b476cc3d7a0f99f11b62675d1a4e01 Mon Sep 17 00:00:00 2001 From: stevenzzzz Date: Wed, 5 Sep 2018 18:04:48 -0400 Subject: [PATCH 07/15] Add terminal attribute to request hash. (#4292) Add a terminal attribute to request hash policy. Think about a case where customers want to hash on a cookie if it's present but if it's not present, do best-effort sticky based on something like IP so the customer has a stable hash. This "terminal" allows request hashing to have the ability of "if A not working, fallback to B.", which also saves time to generate the hash. Changes: * Add a terminal attribute to HashMethod, which shortcircuit the hash generating process if a policy is marked terminal and there is a hash computed already. Signed-off-by: Xin Zhuang stevenzzz@google.com Description: Add terminal attribute to request hash. Risk Level: Low Testing: unit tests. --- api/envoy/api/v2/route/route.proto | 25 ++++++++++- source/common/router/config_impl.cc | 40 ++++++++++++----- source/common/router/config_impl.h | 3 ++ test/common/router/config_impl_test.cc | 59 ++++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 12 deletions(-) diff --git a/api/envoy/api/v2/route/route.proto b/api/envoy/api/v2/route/route.proto index e260cff69a963..a1f58891c8788 100644 --- a/api/envoy/api/v2/route/route.proto +++ b/api/envoy/api/v2/route/route.proto @@ -585,6 +585,27 @@ message RouteAction { // Connection properties hash policy. ConnectionProperties connection_properties = 3; } + + // The flag that shortcircuits the hash computing. This field provides a + // 'fallback' style of configuration: "if a terminal policy doesn't work, + // fallback to rest of the policy list", it saves time when the terminal + // policy works. + // + // If true, and there is already a hash computed, ignore rest of the + // list of hash polices. + // For example, if the following hash methods are configured: + // + // ========= ======== + // specifier terminal + // ========= ======== + // Header A true + // Header B false + // Header C false + // ========= ======== + // + // The generateHash process ends if policy "header A" generates a hash, as + // it's a terminal policy. + bool terminal = 4; } // Specifies a list of hash policies to use for ring hash load balancing. Each @@ -596,7 +617,9 @@ message RouteAction { // hash policies fail to generate a hash, no hash will be produced for // the route. In this case, the behavior is the same as if no hash policies // were specified (i.e. the ring hash load balancer will choose a random - // backend). + // backend). If a hash policy has the "terminal" attribute set to true, and + // there is already a hash generated, the hash is returned immediately, + // ignoring the rest of the hash policy list. repeated HashPolicy hash_policy = 15; // Indicates that a HTTP/1.1 client connection to this particular route is allowed to diff --git a/source/common/router/config_impl.cc b/source/common/router/config_impl.cc index 21b891ae79b36..a48bf4e3baef0 100644 --- a/source/common/router/config_impl.cc +++ b/source/common/router/config_impl.cc @@ -76,9 +76,20 @@ ShadowPolicyImpl::ShadowPolicyImpl(const envoy::api::v2::route::RouteAction& con runtime_key_ = config.request_mirror_policy().runtime_key(); } -class HeaderHashMethod : public HashPolicyImpl::HashMethod { +class HashMethodImplBase : public HashPolicyImpl::HashMethod { public: - HeaderHashMethod(const std::string& header_name) : header_name_(header_name) {} + HashMethodImplBase(bool terminal) : terminal_(terminal) {} + + bool terminal() const override { return terminal_; } + +private: + const bool terminal_; +}; + +class HeaderHashMethod : public HashMethodImplBase { +public: + HeaderHashMethod(const std::string& header_name, bool terminal) + : HashMethodImplBase(terminal), header_name_(header_name) {} absl::optional evaluate(const Network::Address::Instance*, const Http::HeaderMap& headers, @@ -96,18 +107,17 @@ class HeaderHashMethod : public HashPolicyImpl::HashMethod { const Http::LowerCaseString header_name_; }; -class CookieHashMethod : public HashPolicyImpl::HashMethod { +class CookieHashMethod : public HashMethodImplBase { public: CookieHashMethod(const std::string& key, const std::string& path, - const absl::optional& ttl) - : key_(key), path_(path), ttl_(ttl) {} + const absl::optional& ttl, bool terminal) + : HashMethodImplBase(terminal), key_(key), path_(path), ttl_(ttl) {} absl::optional evaluate(const Network::Address::Instance*, const Http::HeaderMap& headers, const HashPolicy::AddCookieCallback add_cookie) const override { absl::optional hash; std::string value = Http::Utility::parseCookieValue(headers, key_); - if (value.empty() && ttl_.has_value()) { value = add_cookie(key_, path_, ttl_.value()); hash = HashUtil::xxHash64(value); @@ -124,8 +134,10 @@ class CookieHashMethod : public HashPolicyImpl::HashMethod { const absl::optional ttl_; }; -class IpHashMethod : public HashPolicyImpl::HashMethod { +class IpHashMethod : public HashMethodImplBase { public: + IpHashMethod(bool terminal) : HashMethodImplBase(terminal) {} + absl::optional evaluate(const Network::Address::Instance* downstream_addr, const Http::HeaderMap&, const HashPolicy::AddCookieCallback) const override { @@ -153,20 +165,21 @@ HashPolicyImpl::HashPolicyImpl( for (auto& hash_policy : hash_policies) { switch (hash_policy.policy_specifier_case()) { case envoy::api::v2::route::RouteAction::HashPolicy::kHeader: - hash_impls_.emplace_back(new HeaderHashMethod(hash_policy.header().header_name())); + hash_impls_.emplace_back( + new HeaderHashMethod(hash_policy.header().header_name(), hash_policy.terminal())); break; case envoy::api::v2::route::RouteAction::HashPolicy::kCookie: { absl::optional ttl; if (hash_policy.cookie().has_ttl()) { ttl = std::chrono::seconds(hash_policy.cookie().ttl().seconds()); } - hash_impls_.emplace_back( - new CookieHashMethod(hash_policy.cookie().name(), hash_policy.cookie().path(), ttl)); + hash_impls_.emplace_back(new CookieHashMethod( + hash_policy.cookie().name(), hash_policy.cookie().path(), ttl, hash_policy.terminal())); break; } case envoy::api::v2::route::RouteAction::HashPolicy::kConnectionProperties: if (hash_policy.connection_properties().source_ip()) { - hash_impls_.emplace_back(new IpHashMethod()); + hash_impls_.emplace_back(new IpHashMethod(hash_policy.terminal())); } break; default: @@ -190,6 +203,11 @@ HashPolicyImpl::generateHash(const Network::Address::Instance* downstream_addr, const uint64_t old_value = hash ? ((hash.value() << 1) | (hash.value() >> 63)) : 0; hash = old_value ^ new_hash.value(); } + // If the policy is a terminal policy and a hash has been generated, ignore + // the rest of the hash policies. + if (hash_impl->terminal() && hash) { + break; + } } return hash; } diff --git a/source/common/router/config_impl.h b/source/common/router/config_impl.h index 238f54b99d51b..3aa34f0a2b925 100644 --- a/source/common/router/config_impl.h +++ b/source/common/router/config_impl.h @@ -234,6 +234,9 @@ class HashPolicyImpl : public HashPolicy { virtual absl::optional evaluate(const Network::Address::Instance* downstream_addr, const Http::HeaderMap& headers, const AddCookieCallback add_cookie) const PURE; + + // If the method is a terminal method, ignore rest of the hash policy chain. + virtual bool terminal() const PURE; }; typedef std::unique_ptr HashMethodPtr; diff --git a/test/common/router/config_impl_test.cc b/test/common/router/config_impl_test.cc index 2b1e7fe587a1c..b063d6408bf79 100644 --- a/test/common/router/config_impl_test.cc +++ b/test/common/router/config_impl_test.cc @@ -1724,6 +1724,65 @@ TEST_F(RouterMatcherHashPolicyTest, HashMultiple) { EXPECT_NE(hash_h, hash_both); } +TEST_F(RouterMatcherHashPolicyTest, HashTerminal) { + // Hash policy list: cookie, header [terminal=true], user_ip. + auto route = route_config_.mutable_virtual_hosts(0)->mutable_routes(0)->mutable_route(); + route->add_hash_policy()->mutable_cookie()->set_name("cookie_hash"); + auto* header_hash = route->add_hash_policy(); + header_hash->mutable_header()->set_header_name("foo_header"); + header_hash->set_terminal(true); + route->add_hash_policy()->mutable_connection_properties()->set_source_ip(true); + Network::Address::Ipv4Instance address1("4.3.2.1"); + Network::Address::Ipv4Instance address2("1.2.3.4"); + + uint64_t hash_1, hash_2; + // Test terminal works when there is hash computed, the rest of the policy + // list is ignored. + { + Http::TestHeaderMapImpl headers = genHeaders("www.lyft.com", "/foo", "GET"); + headers.addCopy("Cookie", "cookie_hash=foo;"); + headers.addCopy("foo_header", "bar"); + Router::RouteConstSharedPtr route = config().route(headers, 0); + hash_1 = route->routeEntry() + ->hashPolicy() + ->generateHash(&address1, headers, add_cookie_nop_) + .value(); + } + { + Http::TestHeaderMapImpl headers = genHeaders("www.lyft.com", "/foo", "GET"); + headers.addCopy("Cookie", "cookie_hash=foo;"); + headers.addCopy("foo_header", "bar"); + Router::RouteConstSharedPtr route = config().route(headers, 0); + hash_2 = route->routeEntry() + ->hashPolicy() + ->generateHash(&address2, headers, add_cookie_nop_) + .value(); + } + EXPECT_EQ(hash_1, hash_2); + + // If no hash computed after evaluating a hash policy, the rest of the policy + // list is evaluated. + { + // Input: {}, {}, address1. Hash on address1. + Http::TestHeaderMapImpl headers = genHeaders("www.lyft.com", "/foo", "GET"); + Router::RouteConstSharedPtr route = config().route(headers, 0); + hash_1 = route->routeEntry() + ->hashPolicy() + ->generateHash(&address1, headers, add_cookie_nop_) + .value(); + } + { + // Input: {}, {}, address2. Hash on address2. + Http::TestHeaderMapImpl headers = genHeaders("www.lyft.com", "/foo", "GET"); + Router::RouteConstSharedPtr route = config().route(headers, 0); + hash_2 = route->routeEntry() + ->hashPolicy() + ->generateHash(&address2, headers, add_cookie_nop_) + .value(); + } + EXPECT_NE(hash_1, hash_2); +} + TEST_F(RouterMatcherHashPolicyTest, InvalidHashPolicies) { NiceMock factory_context; { From cd35a2cf49ebe35c03e8694d3a546347f4d372f6 Mon Sep 17 00:00:00 2001 From: JimmyCYJ Date: Wed, 5 Sep 2018 15:13:41 -0700 Subject: [PATCH 08/15] Revise per comments. Signed-off-by: JimmyCYJ --- source/common/secret/sds_api.cc | 25 +++++++------- source/common/secret/sds_api.h | 42 ++++++++---------------- source/common/ssl/context_config_impl.cc | 13 ++++++-- source/common/ssl/context_config_impl.h | 25 ++++++++++---- 4 files changed, 56 insertions(+), 49 deletions(-) diff --git a/source/common/secret/sds_api.cc b/source/common/secret/sds_api.cc index 7b475d151ae33..a13afc8e5c9bf 100644 --- a/source/common/secret/sds_api.cc +++ b/source/common/secret/sds_api.cc @@ -13,11 +13,12 @@ namespace Envoy { namespace Secret { -SdsApi::SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, - Runtime::RandomGenerator& random, Stats::Store& stats, - Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, - const envoy::api::v2::core::ConfigSource& sds_config, std::string sds_config_name, - std::function destructor_cb) +template +SdsApi::SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, + Runtime::RandomGenerator& random, Stats::Store& stats, + Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, + const envoy::api::v2::core::ConfigSource& sds_config, + std::string sds_config_name, std::function destructor_cb) : secret_hash_(0), local_info_(local_info), dispatcher_(dispatcher), random_(random), stats_(stats), cluster_manager_(cluster_manager), sds_config_(sds_config), sds_config_name_(sds_config_name), clean_up_(destructor_cb) { @@ -28,7 +29,7 @@ SdsApi::SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispat init_manager.registerTarget(*this); } -void SdsApi::initialize(std::function callback) { +template void SdsApi::initialize(std::function callback) { initialize_callback_ = callback; subscription_ = Envoy::Config::SubscriptionFactory::subscriptionFromConfigSource< @@ -42,7 +43,8 @@ void SdsApi::initialize(std::function callback) { subscription_->start({sds_config_name_}, *this); } -void SdsApi::onConfigUpdate(const ResourceVector& resources, const std::string&) { +template +void SdsApi::onConfigUpdate(const ResourceVector& resources, const std::string&) { if (resources.empty()) { throw EnvoyException( fmt::format("Missing SDS resources for {} in onConfigUpdate()", sds_config_name_)); @@ -65,12 +67,12 @@ void SdsApi::onConfigUpdate(const ResourceVector& resources, const std::string&) runInitializeCallbackIfAny(); } -void SdsApi::onConfigUpdateFailed(const EnvoyException*) { +template void SdsApi::onConfigUpdateFailed(const EnvoyException*) { // We need to allow server startup to continue, even if we have a bad config. runInitializeCallbackIfAny(); } -void SdsApi::runInitializeCallbackIfAny() { +template void SdsApi::runInitializeCallbackIfAny() { if (initialize_callback_) { initialize_callback_(); initialize_callback_ = nullptr; @@ -82,8 +84,7 @@ void TlsCertificateSdsApi::updateConfigHelper(const envoy::api::v2::auth::Secret if (new_hash != secret_hash_ && secret.type_case() == envoy::api::v2::auth::Secret::TypeCase::kTlsCertificate) { secret_hash_ = new_hash; - tls_certificate_secrets_ = - std::make_unique(secret.tls_certificate()); + secrets_ = std::make_unique(secret.tls_certificate()); update_callback_manager_.runCallbacks(); } @@ -95,7 +96,7 @@ void CertificateValidationContextSdsApi::updateConfigHelper( if (new_hash != secret_hash_ && secret.type_case() == envoy::api::v2::auth::Secret::TypeCase::kValidationContext) { secret_hash_ = new_hash; - certificate_validation_context_secrets_ = + secrets_ = std::make_unique(secret.validation_context()); update_callback_manager_.runCallbacks(); diff --git a/source/common/secret/sds_api.h b/source/common/secret/sds_api.h index 224994b683852..6d4bb1d1cf8e8 100644 --- a/source/common/secret/sds_api.h +++ b/source/common/secret/sds_api.h @@ -23,7 +23,9 @@ namespace Secret { /** * SDS API implementation that fetches secrets from SDS server via Subscription. */ +template class SdsApi : public Init::Target, + public SecretProvider, public Config::SubscriptionCallbacks { public: SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, @@ -42,10 +44,19 @@ class SdsApi : public Init::Target, return MessageUtil::anyConvert(resource).name(); } + // SecretProvider + const SecretType* secret() const override { return secrets_.get(); } + + Common::CallbackHandle* addUpdateCallback(std::function callback) override { + return update_callback_manager_.add(callback); + } + protected: // Updates local storage of dynamic secrets and invokes callbacks. - virtual void updateConfigHelper(const envoy::api::v2::auth::Secret&) {} + virtual void updateConfigHelper(const envoy::api::v2::auth::Secret&) PURE; uint64_t secret_hash_; + std::unique_ptr secrets_; + Common::CallbackManager<> update_callback_manager_; private: void runInitializeCallbackIfAny(); @@ -67,7 +78,7 @@ class SdsApi : public Init::Target, /** * TlsCertificateSdsApi implementation maintains and updates dynamic TLS certificate secrets. */ -class TlsCertificateSdsApi : public SdsApi, public TlsCertificateConfigProvider { +class TlsCertificateSdsApi : public SdsApi { public: TlsCertificateSdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, Runtime::RandomGenerator& random, Stats::Store& stats, @@ -77,29 +88,16 @@ class TlsCertificateSdsApi : public SdsApi, public TlsCertificateConfigProvider : SdsApi(local_info, dispatcher, random, stats, cluster_manager, init_manager, sds_config, sds_config_name, destructor_cb) {} - // SecretProvider - const Ssl::TlsCertificateConfig* secret() const override { - return tls_certificate_secrets_.get(); - } - - Common::CallbackHandle* addUpdateCallback(std::function callback) override { - return update_callback_manager_.add(callback); - } - private: // SdsApi void updateConfigHelper(const envoy::api::v2::auth::Secret& secret) override; - - Ssl::TlsCertificateConfigPtr tls_certificate_secrets_; - Common::CallbackManager<> update_callback_manager_; }; /** * CertificateValidationContextSdsApi implementation maintains and updates dynamic certificate * validation context secrets. */ -class CertificateValidationContextSdsApi : public SdsApi, - public CertificateValidationContextConfigProvider { +class CertificateValidationContextSdsApi : public SdsApi { public: CertificateValidationContextSdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, @@ -112,21 +110,9 @@ class CertificateValidationContextSdsApi : public SdsApi, : SdsApi(local_info, dispatcher, random, stats, cluster_manager, init_manager, sds_config, sds_config_name, destructor_cb) {} - // SecretProvider - const Ssl::CertificateValidationContextConfig* secret() const override { - return certificate_validation_context_secrets_.get(); - } - - Common::CallbackHandle* addUpdateCallback(std::function callback) override { - return update_callback_manager_.add(callback); - } - private: // SdsApi void updateConfigHelper(const envoy::api::v2::auth::Secret& secret) override; - - Ssl::CertificateValidationContextConfigPtr certificate_validation_context_secrets_; - Common::CallbackManager<> update_callback_manager_; }; } // namespace Secret diff --git a/source/common/ssl/context_config_impl.cc b/source/common/ssl/context_config_impl.cc index aef7c1bdeec19..e6ecfa64189f4 100644 --- a/source/common/ssl/context_config_impl.cc +++ b/source/common/ssl/context_config_impl.cc @@ -65,6 +65,9 @@ getCertificateValidationContextConfigProvider( sds_secret_config.name())); } return secret_provider; + } else { + return factory_context.secretManager().findOrCreateCertificateValidationContextProvider( + sds_secret_config.sds_config(), sds_secret_config.name(), factory_context); } } return nullptr; @@ -98,17 +101,21 @@ ContextConfigImpl::ContextConfigImpl( ecdh_curves_(StringUtil::nonEmptyStringOrDefault( RepeatedPtrUtil::join(config.tls_params().ecdh_curves(), ":"), DEFAULT_ECDH_CURVES)), tls_certficate_provider_(getTlsCertificateConfigProvider(config, factory_context)), - secret_update_callback_handle_(nullptr), + tls_certificate_update_callback_handle_(nullptr), certficate_validation_context_provider_( getCertificateValidationContextConfigProvider(config, factory_context)), + certificate_validation_context_update_callback_handle_(nullptr), min_protocol_version_( tlsVersionFromProto(config.tls_params().tls_minimum_protocol_version(), TLS1_VERSION)), max_protocol_version_(tlsVersionFromProto(config.tls_params().tls_maximum_protocol_version(), TLS1_2_VERSION)) {} ContextConfigImpl::~ContextConfigImpl() { - if (secret_update_callback_handle_) { - secret_update_callback_handle_->remove(); + if (tls_certificate_update_callback_handle_) { + tls_certificate_update_callback_handle_->remove(); + } + if (certificate_validation_context_update_callback_handle_) { + certificate_validation_context_update_callback_handle_->remove(); } } diff --git a/source/common/ssl/context_config_impl.h b/source/common/ssl/context_config_impl.h index 7e6443be9a090..90e048cc8495e 100644 --- a/source/common/ssl/context_config_impl.h +++ b/source/common/ssl/context_config_impl.h @@ -39,16 +39,28 @@ class ContextConfigImpl : public virtual Ssl::ContextConfig { bool isReady() const override { // Either tls_certficate_provider_ is nullptr or - // tls_certficate_provider_->secret() is NOT nullptr. - return !tls_certficate_provider_ || tls_certficate_provider_->secret() != nullptr; + // tls_certficate_provider_->secret() is NOT nullptr and + // either certficate_validation_context_provider_ is nullptr or + // certficate_validation_context_provider_->secret() is NOT nullptr. + return (!tls_certficate_provider_ || tls_certficate_provider_->secret() != nullptr) && + (!certficate_validation_context_provider_ || + certficate_validation_context_provider_->secret() != nullptr); } void setSecretUpdateCallback(std::function callback) override { if (tls_certficate_provider_) { - if (secret_update_callback_handle_) { - secret_update_callback_handle_->remove(); + if (tls_certificate_update_callback_handle_) { + tls_certificate_update_callback_handle_->remove(); } - secret_update_callback_handle_ = tls_certficate_provider_->addUpdateCallback(callback); + tls_certificate_update_callback_handle_ = + tls_certficate_provider_->addUpdateCallback(callback); + } + if (certficate_validation_context_provider_) { + if (certificate_validation_context_update_callback_handle_) { + certificate_validation_context_update_callback_handle_->remove(); + } + certificate_validation_context_update_callback_handle_ = + certficate_validation_context_provider_->addUpdateCallback(callback); } } @@ -69,9 +81,10 @@ class ContextConfigImpl : public virtual Ssl::ContextConfig { const std::string cipher_suites_; const std::string ecdh_curves_; Secret::TlsCertificateConfigProviderSharedPtr tls_certficate_provider_; - Common::CallbackHandle* secret_update_callback_handle_; + Common::CallbackHandle* tls_certificate_update_callback_handle_; Secret::CertificateValidationContextConfigProviderSharedPtr certficate_validation_context_provider_; + Common::CallbackHandle* certificate_validation_context_update_callback_handle_; const unsigned min_protocol_version_; const unsigned max_protocol_version_; }; From 61bdf6c6c757c8eb9b4c5801f011fa99858c60ed Mon Sep 17 00:00:00 2001 From: JimmyCYJ Date: Wed, 5 Sep 2018 15:17:54 -0700 Subject: [PATCH 09/15] pass sds config name as const std::string&. Signed-off-by: JimmyCYJ --- source/common/secret/sds_api.cc | 2 +- source/common/secret/sds_api.h | 2 +- source/common/secret/secret_manager_impl.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/common/secret/sds_api.cc b/source/common/secret/sds_api.cc index a13afc8e5c9bf..8119dbebabfc9 100644 --- a/source/common/secret/sds_api.cc +++ b/source/common/secret/sds_api.cc @@ -18,7 +18,7 @@ SdsApi::SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispat Runtime::RandomGenerator& random, Stats::Store& stats, Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, const envoy::api::v2::core::ConfigSource& sds_config, - std::string sds_config_name, std::function destructor_cb) + const std::string& sds_config_name, std::function destructor_cb) : secret_hash_(0), local_info_(local_info), dispatcher_(dispatcher), random_(random), stats_(stats), cluster_manager_(cluster_manager), sds_config_(sds_config), sds_config_name_(sds_config_name), clean_up_(destructor_cb) { diff --git a/source/common/secret/sds_api.h b/source/common/secret/sds_api.h index 6d4bb1d1cf8e8..db88ed4943328 100644 --- a/source/common/secret/sds_api.h +++ b/source/common/secret/sds_api.h @@ -31,7 +31,7 @@ class SdsApi : public Init::Target, SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, Runtime::RandomGenerator& random, Stats::Store& stats, Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, - const envoy::api::v2::core::ConfigSource& sds_config, std::string sds_config_name, + const envoy::api::v2::core::ConfigSource& sds_config, const std::string& sds_config_name, std::function destructor_cb); // Init::Target diff --git a/source/common/secret/secret_manager_impl.cc b/source/common/secret/secret_manager_impl.cc index 09e701982b4ea..111094ee1922f 100644 --- a/source/common/secret/secret_manager_impl.cc +++ b/source/common/secret/secret_manager_impl.cc @@ -103,7 +103,7 @@ void SecretManagerImpl::removeDynamicCertificateValidationContextProvider( ENVOY_LOG(debug, "Unregister certificate validation context provider. hash key: {}", map_key); auto num_deleted = dynamic_certificate_validation_context_providers_.erase(map_key); - ASSERT(num_deleted == 1, ""); + ASSERT(num_deleted == 1); } CertificateValidationContextConfigProviderSharedPtr From cf3cbf3b291205458bb867d9324584efbfde2f3f Mon Sep 17 00:00:00 2001 From: JimmyCYJ Date: Wed, 5 Sep 2018 15:32:28 -0700 Subject: [PATCH 10/15] Add explicit instantiation. Signed-off-by: JimmyCYJ --- source/common/secret/sds_api.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/source/common/secret/sds_api.cc b/source/common/secret/sds_api.cc index 8119dbebabfc9..1da2c4de31034 100644 --- a/source/common/secret/sds_api.cc +++ b/source/common/secret/sds_api.cc @@ -13,6 +13,9 @@ namespace Envoy { namespace Secret { +template class Secret::SdsApi; +template class Secret::SdsApi; + template SdsApi::SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, Runtime::RandomGenerator& random, Stats::Store& stats, From a905c0a69009def4949ad9278f101e61257c5f9a Mon Sep 17 00:00:00 2001 From: JimmyCYJ Date: Wed, 5 Sep 2018 18:18:36 -0700 Subject: [PATCH 11/15] Update tests. Signed-off-by: JimmyCYJ --- test/common/secret/sds_api_test.cc | 25 +++++++++++-------- .../common/secret/secret_manager_impl_test.cc | 2 +- test/mocks/secret/mocks.h | 5 ++++ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/test/common/secret/sds_api_test.cc b/test/common/secret/sds_api_test.cc index 17435fe5c8d80..084faace73c90 100644 --- a/test/common/secret/sds_api_test.cc +++ b/test/common/secret/sds_api_test.cc @@ -41,8 +41,9 @@ TEST_F(SdsApiTest, BasicTest) { auto google_grpc = grpc_service->mutable_google_grpc(); google_grpc->set_target_uri("fake_address"); google_grpc->set_stat_prefix("test"); - SdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), server.stats(), - server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + TlsCertificateSdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), + server.stats(), server.clusterManager(), init_manager, config_source, + "abc.com", []() {}); NiceMock* grpc_client{new NiceMock()}; NiceMock* factory{new NiceMock()}; @@ -62,8 +63,9 @@ TEST_F(SdsApiTest, SecretUpdateSuccess) { NiceMock server; NiceMock init_manager; envoy::api::v2::core::ConfigSource config_source; - SdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), server.stats(), - server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + TlsCertificateSdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), + server.stats(), server.clusterManager(), init_manager, config_source, + "abc.com", []() {}); NiceMock secret_callback; auto handle = @@ -101,8 +103,9 @@ TEST_F(SdsApiTest, EmptyResource) { NiceMock server; NiceMock init_manager; envoy::api::v2::core::ConfigSource config_source; - SdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), server.stats(), - server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + TlsCertificateSdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), + server.stats(), server.clusterManager(), init_manager, config_source, + "abc.com", []() {}); Protobuf::RepeatedPtrField secret_resources; @@ -115,8 +118,9 @@ TEST_F(SdsApiTest, SecretUpdateWrongSize) { NiceMock server; NiceMock init_manager; envoy::api::v2::core::ConfigSource config_source; - SdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), server.stats(), - server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + TlsCertificateSdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), + server.stats(), server.clusterManager(), init_manager, config_source, + "abc.com", []() {}); std::string yaml = R"EOF( @@ -144,8 +148,9 @@ TEST_F(SdsApiTest, SecretUpdateWrongSecretName) { NiceMock server; NiceMock init_manager; envoy::api::v2::core::ConfigSource config_source; - SdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), server.stats(), - server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + TlsCertificateSdsApi sds_api(server.localInfo(), server.dispatcher(), server.random(), + server.stats(), server.clusterManager(), init_manager, config_source, + "abc.com", []() {}); std::string yaml = R"EOF( diff --git a/test/common/secret/secret_manager_impl_test.cc b/test/common/secret/secret_manager_impl_test.cc index c340e63342f47..4432d9e94e364 100644 --- a/test/common/secret/secret_manager_impl_test.cc +++ b/test/common/secret/secret_manager_impl_test.cc @@ -171,7 +171,7 @@ name: "abc.com" Protobuf::RepeatedPtrField secret_resources; auto secret_config = secret_resources.Add(); MessageUtil::loadFromYaml(TestEnvironment::substitute(yaml), *secret_config); - dynamic_cast(*secret_provider).onConfigUpdate(secret_resources, ""); + dynamic_cast(*secret_provider).onConfigUpdate(secret_resources, ""); const std::string cert_pem = "{{ test_rundir }}/test/common/ssl/test_data/selfsigned_cert.pem"; EXPECT_EQ(TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(cert_pem)), secret_provider->secret()->certificateChain()); diff --git a/test/mocks/secret/mocks.h b/test/mocks/secret/mocks.h index 2637d6d850325..428f1ec28faad 100644 --- a/test/mocks/secret/mocks.h +++ b/test/mocks/secret/mocks.h @@ -32,6 +32,11 @@ class MockSecretManager : public SecretManager { TlsCertificateConfigProviderSharedPtr( const envoy::api::v2::core::ConfigSource&, const std::string&, Server::Configuration::TransportSocketFactoryContext&)); + MOCK_METHOD3(findOrCreateCertificateValidationContextProvider, + CertificateValidationContextConfigProviderSharedPtr( + const envoy::api::v2::core::ConfigSource& config_source, + const std::string& config_name, + Server::Configuration::TransportSocketFactoryContext& secret_provider_context)); }; class MockSecretCallbacks : public SecretCallbacks { From 7707c3ba095da6c4bd55afb4a064907443792c7c Mon Sep 17 00:00:00 2001 From: "tianqian.zyf" <445188383@qq.com> Date: Thu, 6 Sep 2018 11:50:12 +0800 Subject: [PATCH 12/15] fix time_since_epoch different in different os default return precision (#4288) Risk Level: low Testing: N/A Docs Changes: Release Notes: Fixes #4278 Signed-off-by: tianqian.zyf --- .../stat_sinks/metrics_service/config.cc | 2 +- .../grpc_metrics_service_impl.cc | 17 ++++++++++++----- .../metrics_service/grpc_metrics_service_impl.h | 4 +++- .../grpc_metrics_service_impl_test.cc | 7 +++++-- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/source/extensions/stat_sinks/metrics_service/config.cc b/source/extensions/stat_sinks/metrics_service/config.cc index 234b46b377406..22e9e58fad1ff 100644 --- a/source/extensions/stat_sinks/metrics_service/config.cc +++ b/source/extensions/stat_sinks/metrics_service/config.cc @@ -29,7 +29,7 @@ Stats::SinkPtr MetricsServiceSinkFactory::createStatsSink(const Protobuf::Messag grpc_service, server.stats(), false), server.threadLocal(), server.localInfo()); - return std::make_unique(grpc_metrics_streamer); + return std::make_unique(grpc_metrics_streamer, server.timeSource()); } ProtobufTypes::MessagePtr MetricsServiceSinkFactory::createEmptyConfigProto() { diff --git a/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.cc b/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.cc index ce68afd5f98aa..b218a57f1c74e 100644 --- a/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.cc +++ b/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.cc @@ -61,15 +61,18 @@ void GrpcMetricsStreamerImpl::ThreadLocalStreamer::send( } } -MetricsServiceSink::MetricsServiceSink(const GrpcMetricsStreamerSharedPtr& grpc_metrics_streamer) - : grpc_metrics_streamer_(grpc_metrics_streamer) {} +MetricsServiceSink::MetricsServiceSink(const GrpcMetricsStreamerSharedPtr& grpc_metrics_streamer, + TimeSource& time_source) + : grpc_metrics_streamer_(grpc_metrics_streamer), time_source_(time_source) {} void MetricsServiceSink::flushCounter(const Stats::Counter& counter) { io::prometheus::client::MetricFamily* metrics_family = message_.add_envoy_metrics(); metrics_family->set_type(io::prometheus::client::MetricType::COUNTER); metrics_family->set_name(counter.name()); auto* metric = metrics_family->add_metric(); - metric->set_timestamp_ms(std::chrono::system_clock::now().time_since_epoch().count()); + metric->set_timestamp_ms(std::chrono::duration_cast( + time_source_.systemTime().time_since_epoch()) + .count()); auto* counter_metric = metric->mutable_counter(); counter_metric->set_value(counter.value()); } @@ -79,7 +82,9 @@ void MetricsServiceSink::flushGauge(const Stats::Gauge& gauge) { metrics_family->set_type(io::prometheus::client::MetricType::GAUGE); metrics_family->set_name(gauge.name()); auto* metric = metrics_family->add_metric(); - metric->set_timestamp_ms(std::chrono::system_clock::now().time_since_epoch().count()); + metric->set_timestamp_ms(std::chrono::duration_cast( + time_source_.systemTime().time_since_epoch()) + .count()); auto* gauage_metric = metric->mutable_gauge(); gauage_metric->set_value(gauge.value()); } @@ -88,7 +93,9 @@ void MetricsServiceSink::flushHistogram(const Stats::ParentHistogram& histogram) metrics_family->set_type(io::prometheus::client::MetricType::SUMMARY); metrics_family->set_name(histogram.name()); auto* metric = metrics_family->add_metric(); - metric->set_timestamp_ms(std::chrono::system_clock::now().time_since_epoch().count()); + metric->set_timestamp_ms(std::chrono::duration_cast( + time_source_.systemTime().time_since_epoch()) + .count()); auto* summary_metric = metric->mutable_summary(); const Stats::HistogramStatistics& hist_stats = histogram.intervalStatistics(); for (size_t i = 0; i < hist_stats.supportedQuantiles().size(); i++) { diff --git a/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.h b/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.h index b27792652d788..bb796052a311b 100644 --- a/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.h +++ b/source/extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.h @@ -111,7 +111,8 @@ class GrpcMetricsStreamerImpl : public Singleton::Instance, public GrpcMetricsSt class MetricsServiceSink : public Stats::Sink { public: // MetricsService::Sink - MetricsServiceSink(const GrpcMetricsStreamerSharedPtr& grpc_metrics_streamer); + MetricsServiceSink(const GrpcMetricsStreamerSharedPtr& grpc_metrics_streamer, + TimeSource& time_source); void flush(Stats::Source& source) override; void onHistogramComplete(const Stats::Histogram&, uint64_t) override {} @@ -122,6 +123,7 @@ class MetricsServiceSink : public Stats::Sink { private: GrpcMetricsStreamerSharedPtr grpc_metrics_streamer_; envoy::service::metrics::v2::StreamMetricsMessage message_; + TimeSource& time_source_; }; } // namespace MetricsService diff --git a/test/extensions/stats_sinks/metrics_service/grpc_metrics_service_impl_test.cc b/test/extensions/stats_sinks/metrics_service/grpc_metrics_service_impl_test.cc index f2996770d3a26..627eabf08dab0 100644 --- a/test/extensions/stats_sinks/metrics_service/grpc_metrics_service_impl_test.cc +++ b/test/extensions/stats_sinks/metrics_service/grpc_metrics_service_impl_test.cc @@ -1,5 +1,6 @@ #include "extensions/stat_sinks/metrics_service/grpc_metrics_service_impl.h" +#include "test/mocks/common.h" #include "test/mocks/grpc/mocks.h" #include "test/mocks/local_info/mocks.h" #include "test/mocks/stats/mocks.h" @@ -98,9 +99,10 @@ class MetricsServiceSinkTest : public testing::Test {}; TEST(MetricsServiceSinkTest, CheckSendCall) { NiceMock source; + NiceMock mock_time; std::shared_ptr streamer_{new MockGrpcMetricsStreamer()}; - MetricsServiceSink sink(streamer_); + MetricsServiceSink sink(streamer_, mock_time); auto counter = std::make_shared>(); counter->name_ = "test_counter"; @@ -125,9 +127,10 @@ TEST(MetricsServiceSinkTest, CheckSendCall) { TEST(MetricsServiceSinkTest, CheckStatsCount) { NiceMock source; + NiceMock mock_time; std::shared_ptr streamer_{new TestGrpcMetricsStreamer()}; - MetricsServiceSink sink(streamer_); + MetricsServiceSink sink(streamer_, mock_time); auto counter = std::make_shared>(); counter->name_ = "test_counter"; From 1537e6828e5652ee1de5cdc7aae70ca3ce448091 Mon Sep 17 00:00:00 2001 From: Stephan Zuercher Date: Thu, 6 Sep 2018 06:26:40 -0700 Subject: [PATCH 13/15] test: another echo_integration_test fix (#4350) OS X's localhost interface often takes longer to fail to connect than Linux. Modifies RawConnectionDriver to detect when its underlying connection is connected or failed and pauses the echo_integration_test until that state is reached before checking the connection state. Risk Level: low, test-only changes Testing: existing tests Doc Changes: n/a Release Notes: n/a Signed-off-by: Stephan Zuercher stephan@turbinelabs.io --- test/integration/echo_integration_test.cc | 7 ++++++- test/integration/utility.cc | 2 ++ test/integration/utility.h | 10 ++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/test/integration/echo_integration_test.cc b/test/integration/echo_integration_test.cc index 1eef8fb6a0ebb..d38c2eb131569 100644 --- a/test/integration/echo_integration_test.cc +++ b/test/integration/echo_integration_test.cc @@ -130,7 +130,12 @@ TEST_P(EchoIntegrationTest, AddRemoveListener) { RawConnectionDriver connection2( new_listener_port, buffer, [&](Network::ClientConnection&, const Buffer::Instance&) -> void { FAIL(); }, version_); - connection2.run(Event::Dispatcher::RunType::NonBlock); + while (connection2.connecting()) { + // Don't busy loop, but OS X often needs a moment to decide this connection isn't happening. + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + connection2.run(Event::Dispatcher::RunType::NonBlock); + } if (connection2.connection().state() == Network::Connection::State::Closed) { connect_fail = true; break; diff --git a/test/integration/utility.cc b/test/integration/utility.cc index ff815315c6531..9e136407254fd 100644 --- a/test/integration/utility.cc +++ b/test/integration/utility.cc @@ -111,10 +111,12 @@ RawConnectionDriver::RawConnectionDriver(uint32_t port, Buffer::Instance& initia Network::Address::IpVersion version) { api_.reset(new Api::Impl(std::chrono::milliseconds(10000))); dispatcher_ = api_->allocateDispatcher(IntegrationUtil::evil_singleton_test_time_.timeSource()); + callbacks_ = std::make_unique(); client_ = dispatcher_->createClientConnection( Network::Utility::resolveUrl( fmt::format("tcp://{}:{}", Network::Test::getLoopbackAddressUrlString(version), port)), Network::Address::InstanceConstSharedPtr(), Network::Test::createRawBufferSocket(), nullptr); + client_->addConnectionCallbacks(*callbacks_); client_->addReadFilter(Network::ReadFilterSharedPtr{new ForwardingFilter(*this, data_callback)}); client_->write(initial_data, false); client_->connect(); diff --git a/test/integration/utility.h b/test/integration/utility.h index a1bd07f5ece94..04e5d650eebb3 100644 --- a/test/integration/utility.h +++ b/test/integration/utility.h @@ -62,6 +62,7 @@ class RawConnectionDriver { Network::Address::IpVersion version); ~RawConnectionDriver(); const Network::Connection& connection() { return *client_; } + bool connecting() { return callbacks_->connecting_; } void run(Event::Dispatcher::RunType run_type = Event::Dispatcher::RunType::Block); void close(); @@ -81,8 +82,17 @@ class RawConnectionDriver { ReadCallback data_callback_; }; + struct ConnectionCallbacks : public Network::ConnectionCallbacks { + void onEvent(Network::ConnectionEvent) override { connecting_ = false; } + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + + bool connecting_{true}; + }; + Api::ApiPtr api_; Event::DispatcherPtr dispatcher_; + std::unique_ptr callbacks_; Network::ClientConnectionPtr client_; }; From c4211b3195c885a1547070f02acdea0eb88e316a Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Thu, 6 Sep 2018 10:06:09 -0700 Subject: [PATCH 14/15] not to writeQueued if cq is drained (#4356) Risk Level: Low Testing: bazel test --runs_per_test=1000 //test/integration:sds_dynamic_integration_test Signed-off-by: Wayne Zhang --- source/common/grpc/google_async_client_impl.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/common/grpc/google_async_client_impl.cc b/source/common/grpc/google_async_client_impl.cc index 7fdc57768d6c5..7a929a8f24f92 100644 --- a/source/common/grpc/google_async_client_impl.cc +++ b/source/common/grpc/google_async_client_impl.cc @@ -205,7 +205,8 @@ void GoogleAsyncStreamImpl::resetStream() { } void GoogleAsyncStreamImpl::writeQueued() { - if (!call_initialized_ || finish_pending_ || write_pending_ || write_pending_queue_.empty()) { + if (!call_initialized_ || finish_pending_ || write_pending_ || write_pending_queue_.empty() || + draining_cq_) { return; } write_pending_ = true; From a12b042c5b6a896fe9e79864c0c516da7097449b Mon Sep 17 00:00:00 2001 From: JimmyCYJ Date: Thu, 6 Sep 2018 17:23:44 -0700 Subject: [PATCH 15/15] Add test and refactor SdsApi. Signed-off-by: JimmyCYJ --- source/common/secret/sds_api.cc | 28 ++++++------- source/common/secret/sds_api.h | 37 +++++++++++------ source/common/secret/secret_manager_impl.cc | 30 +++++--------- source/common/secret/secret_manager_impl.h | 16 +++----- test/common/secret/sds_api_test.cc | 40 ++++++++++++++++++- .../sds_static_integration_test.cc | 12 ++++-- 6 files changed, 98 insertions(+), 65 deletions(-) diff --git a/source/common/secret/sds_api.cc b/source/common/secret/sds_api.cc index 1da2c4de31034..1678df0828d99 100644 --- a/source/common/secret/sds_api.cc +++ b/source/common/secret/sds_api.cc @@ -13,15 +13,11 @@ namespace Envoy { namespace Secret { -template class Secret::SdsApi; -template class Secret::SdsApi; - -template -SdsApi::SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, - Runtime::RandomGenerator& random, Stats::Store& stats, - Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, - const envoy::api::v2::core::ConfigSource& sds_config, - const std::string& sds_config_name, std::function destructor_cb) +SdsApi::SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, + Runtime::RandomGenerator& random, Stats::Store& stats, + Upstream::ClusterManager& cluster_manager, Init::Manager& init_manager, + const envoy::api::v2::core::ConfigSource& sds_config, + const std::string& sds_config_name, std::function destructor_cb) : secret_hash_(0), local_info_(local_info), dispatcher_(dispatcher), random_(random), stats_(stats), cluster_manager_(cluster_manager), sds_config_(sds_config), sds_config_name_(sds_config_name), clean_up_(destructor_cb) { @@ -32,7 +28,7 @@ SdsApi::SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispat init_manager.registerTarget(*this); } -template void SdsApi::initialize(std::function callback) { +void SdsApi::initialize(std::function callback) { initialize_callback_ = callback; subscription_ = Envoy::Config::SubscriptionFactory::subscriptionFromConfigSource< @@ -46,8 +42,7 @@ template void SdsApi::initialize(std::functionstart({sds_config_name_}, *this); } -template -void SdsApi::onConfigUpdate(const ResourceVector& resources, const std::string&) { +void SdsApi::onConfigUpdate(const ResourceVector& resources, const std::string&) { if (resources.empty()) { throw EnvoyException( fmt::format("Missing SDS resources for {} in onConfigUpdate()", sds_config_name_)); @@ -70,12 +65,12 @@ void SdsApi::onConfigUpdate(const ResourceVector& resources, const s runInitializeCallbackIfAny(); } -template void SdsApi::onConfigUpdateFailed(const EnvoyException*) { +void SdsApi::onConfigUpdateFailed(const EnvoyException*) { // We need to allow server startup to continue, even if we have a bad config. runInitializeCallbackIfAny(); } -template void SdsApi::runInitializeCallbackIfAny() { +void SdsApi::runInitializeCallbackIfAny() { if (initialize_callback_) { initialize_callback_(); initialize_callback_ = nullptr; @@ -87,7 +82,8 @@ void TlsCertificateSdsApi::updateConfigHelper(const envoy::api::v2::auth::Secret if (new_hash != secret_hash_ && secret.type_case() == envoy::api::v2::auth::Secret::TypeCase::kTlsCertificate) { secret_hash_ = new_hash; - secrets_ = std::make_unique(secret.tls_certificate()); + tls_certificate_secrets_ = + std::make_unique(secret.tls_certificate()); update_callback_manager_.runCallbacks(); } @@ -99,7 +95,7 @@ void CertificateValidationContextSdsApi::updateConfigHelper( if (new_hash != secret_hash_ && secret.type_case() == envoy::api::v2::auth::Secret::TypeCase::kValidationContext) { secret_hash_ = new_hash; - secrets_ = + certificate_validation_context_secrets_ = std::make_unique(secret.validation_context()); update_callback_manager_.runCallbacks(); diff --git a/source/common/secret/sds_api.h b/source/common/secret/sds_api.h index db88ed4943328..23f31a4100ea7 100644 --- a/source/common/secret/sds_api.h +++ b/source/common/secret/sds_api.h @@ -23,9 +23,7 @@ namespace Secret { /** * SDS API implementation that fetches secrets from SDS server via Subscription. */ -template class SdsApi : public Init::Target, - public SecretProvider, public Config::SubscriptionCallbacks { public: SdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, @@ -44,18 +42,10 @@ class SdsApi : public Init::Target, return MessageUtil::anyConvert(resource).name(); } - // SecretProvider - const SecretType* secret() const override { return secrets_.get(); } - - Common::CallbackHandle* addUpdateCallback(std::function callback) override { - return update_callback_manager_.add(callback); - } - protected: // Updates local storage of dynamic secrets and invokes callbacks. virtual void updateConfigHelper(const envoy::api::v2::auth::Secret&) PURE; uint64_t secret_hash_; - std::unique_ptr secrets_; Common::CallbackManager<> update_callback_manager_; private: @@ -75,10 +65,12 @@ class SdsApi : public Init::Target, Cleanup clean_up_; }; +typedef std::shared_ptr SdsApiSharedPtr; + /** * TlsCertificateSdsApi implementation maintains and updates dynamic TLS certificate secrets. */ -class TlsCertificateSdsApi : public SdsApi { +class TlsCertificateSdsApi : public SdsApi, public TlsCertificateConfigProvider { public: TlsCertificateSdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, Runtime::RandomGenerator& random, Stats::Store& stats, @@ -88,16 +80,27 @@ class TlsCertificateSdsApi : public SdsApi { : SdsApi(local_info, dispatcher, random, stats, cluster_manager, init_manager, sds_config, sds_config_name, destructor_cb) {} + // SecretProvider + const Ssl::TlsCertificateConfig* secret() const override { + return tls_certificate_secrets_.get(); + } + Common::CallbackHandle* addUpdateCallback(std::function callback) override { + return update_callback_manager_.add(callback); + } + private: // SdsApi void updateConfigHelper(const envoy::api::v2::auth::Secret& secret) override; + + Ssl::TlsCertificateConfigPtr tls_certificate_secrets_; }; /** * CertificateValidationContextSdsApi implementation maintains and updates dynamic certificate * validation context secrets. */ -class CertificateValidationContextSdsApi : public SdsApi { +class CertificateValidationContextSdsApi : public SdsApi, + public CertificateValidationContextConfigProvider { public: CertificateValidationContextSdsApi(const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, @@ -110,9 +113,19 @@ class CertificateValidationContextSdsApi : public SdsApi callback) override { + return update_callback_manager_.add(callback); + } + private: // SdsApi void updateConfigHelper(const envoy::api::v2::auth::Secret& secret) override; + + Ssl::CertificateValidationContextConfigPtr certificate_validation_context_secrets_; }; } // namespace Secret diff --git a/source/common/secret/secret_manager_impl.cc b/source/common/secret/secret_manager_impl.cc index 111094ee1922f..43fe82126fc06 100644 --- a/source/common/secret/secret_manager_impl.cc +++ b/source/common/secret/secret_manager_impl.cc @@ -64,10 +64,10 @@ SecretManagerImpl::createInlineCertificateValidationContextProvider( certificate_validation_context); } -void SecretManagerImpl::removeDynamicTlsCertificateProvider(const std::string& map_key) { +void SecretManagerImpl::removeDynamicSecretProvider(const std::string& map_key) { ENVOY_LOG(debug, "Unregister tls certificate provider. hash key: {}", map_key); - auto num_deleted = dynamic_tls_certificate_providers_.erase(map_key); + auto num_deleted = dynamic_secret_providers_.erase(map_key); ASSERT(num_deleted == 1, ""); } @@ -76,15 +76,14 @@ TlsCertificateConfigProviderSharedPtr SecretManagerImpl::findOrCreateTlsCertific Server::Configuration::TransportSocketFactoryContext& secret_provider_context) { const std::string map_key = sds_config_source.SerializeAsString() + config_name; - TlsCertificateConfigProviderSharedPtr secret_provider = - dynamic_tls_certificate_providers_[map_key].lock(); + SdsApiSharedPtr secret_provider = dynamic_secret_providers_[map_key].lock(); if (!secret_provider) { ASSERT(secret_provider_context.initManager() != nullptr); // SdsApi is owned by ListenerImpl and ClusterInfo which are destroyed before // SecretManagerImpl. It is safe to invoke this callback at the destructor of SdsApi. std::function unregister_secret_provider = [map_key, this]() { - removeDynamicTlsCertificateProvider(map_key); + removeDynamicSecretProvider(map_key); }; secret_provider = std::make_shared( @@ -92,18 +91,10 @@ TlsCertificateConfigProviderSharedPtr SecretManagerImpl::findOrCreateTlsCertific secret_provider_context.random(), secret_provider_context.stats(), secret_provider_context.clusterManager(), *secret_provider_context.initManager(), sds_config_source, config_name, unregister_secret_provider); - dynamic_tls_certificate_providers_[map_key] = secret_provider; + dynamic_secret_providers_[map_key] = secret_provider; } - return secret_provider; -} - -void SecretManagerImpl::removeDynamicCertificateValidationContextProvider( - const std::string& map_key) { - ENVOY_LOG(debug, "Unregister certificate validation context provider. hash key: {}", map_key); - - auto num_deleted = dynamic_certificate_validation_context_providers_.erase(map_key); - ASSERT(num_deleted == 1); + return std::dynamic_pointer_cast(secret_provider); } CertificateValidationContextConfigProviderSharedPtr @@ -112,15 +103,14 @@ SecretManagerImpl::findOrCreateCertificateValidationContextProvider( Server::Configuration::TransportSocketFactoryContext& secret_provider_context) { const std::string map_key = sds_config_source.SerializeAsString() + config_name; - CertificateValidationContextConfigProviderSharedPtr secret_provider = - dynamic_certificate_validation_context_providers_[map_key].lock(); + SdsApiSharedPtr secret_provider = dynamic_secret_providers_[map_key].lock(); if (!secret_provider) { ASSERT(secret_provider_context.initManager() != nullptr); // SdsApi is owned by ListenerImpl and ClusterInfo which are destroyed before // SecretManagerImpl. It is safe to invoke this callback at the destructor of SdsApi. std::function unregister_secret_provider = [map_key, this]() { - removeDynamicCertificateValidationContextProvider(map_key); + removeDynamicSecretProvider(map_key); }; secret_provider = std::make_shared( @@ -128,10 +118,10 @@ SecretManagerImpl::findOrCreateCertificateValidationContextProvider( secret_provider_context.random(), secret_provider_context.stats(), secret_provider_context.clusterManager(), *secret_provider_context.initManager(), sds_config_source, config_name, unregister_secret_provider); - dynamic_certificate_validation_context_providers_[map_key] = secret_provider; + dynamic_secret_providers_[map_key] = secret_provider; } - return secret_provider; + return std::dynamic_pointer_cast(secret_provider); } } // namespace Secret diff --git a/source/common/secret/secret_manager_impl.h b/source/common/secret/secret_manager_impl.h index 97c18297fdc96..20e2ac261b33c 100644 --- a/source/common/secret/secret_manager_impl.h +++ b/source/common/secret/secret_manager_impl.h @@ -9,6 +9,7 @@ #include "envoy/ssl/tls_certificate_config.h" #include "common/common/logger.h" +#include "common/secret/sds_api.h" namespace Envoy { namespace Secret { @@ -41,10 +42,8 @@ class SecretManagerImpl : public SecretManager, Logger::Loggable @@ -54,13 +53,8 @@ class SecretManagerImpl : public SecretManager, Logger::Loggable static_certificate_validation_context_providers_; - // map hash code of SDS config source and TlsCertificateSdsApi object. - std::unordered_map> - dynamic_tls_certificate_providers_; - - // map hash code of SDS config source and CertificateValidationContextSdsApi object. - std::unordered_map> - dynamic_certificate_validation_context_providers_; + // map hash code of SDS config source and SdsApi object. + std::unordered_map> dynamic_secret_providers_; }; } // namespace Secret diff --git a/test/common/secret/sds_api_test.cc b/test/common/secret/sds_api_test.cc index 084faace73c90..7a4432e4ff744 100644 --- a/test/common/secret/sds_api_test.cc +++ b/test/common/secret/sds_api_test.cc @@ -58,8 +58,9 @@ TEST_F(SdsApiTest, BasicTest) { init_manager.initialize(); } -// Validate that SdsApi updates secrets successfully if a good secret is passed to onConfigUpdate(). -TEST_F(SdsApiTest, SecretUpdateSuccess) { +// Validate that TlsCertificateSdsApi updates secrets successfully if a good secret +// is passed to onConfigUpdate(). +TEST_F(SdsApiTest, DynamicTlsCertificateUpdateSuccess) { NiceMock server; NiceMock init_manager; envoy::api::v2::core::ConfigSource config_source; @@ -98,6 +99,41 @@ TEST_F(SdsApiTest, SecretUpdateSuccess) { handle->remove(); } +// Validate that CertificateValidationContextSdsApi updates secrets successfully if +// a good secret is passed to onConfigUpdate(). +TEST_F(SdsApiTest, DynamicCertificateValidationContextUpdateSuccess) { + NiceMock server; + NiceMock init_manager; + envoy::api::v2::core::ConfigSource config_source; + CertificateValidationContextSdsApi sds_api( + server.localInfo(), server.dispatcher(), server.random(), server.stats(), + server.clusterManager(), init_manager, config_source, "abc.com", []() {}); + + NiceMock secret_callback; + auto handle = + sds_api.addUpdateCallback([&secret_callback]() { secret_callback.onAddOrUpdateSecret(); }); + + std::string yaml = + R"EOF( + name: "abc.com" + validation_context: + trusted_ca: { filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" } + allow_expired_certificate: true + )EOF"; + + Protobuf::RepeatedPtrField secret_resources; + auto secret_config = secret_resources.Add(); + MessageUtil::loadFromYaml(TestEnvironment::substitute(yaml), *secret_config); + EXPECT_CALL(secret_callback, onAddOrUpdateSecret()); + sds_api.onConfigUpdate(secret_resources, ""); + + const std::string ca_cert = "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem"; + EXPECT_EQ(TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(ca_cert)), + sds_api.secret()->caCert()); + + handle->remove(); +} + // Validate that SdsApi throws exception if an empty secret is passed to onConfigUpdate(). TEST_F(SdsApiTest, EmptyResource) { NiceMock server; diff --git a/test/integration/sds_static_integration_test.cc b/test/integration/sds_static_integration_test.cc index cc259d31ffe7f..f7551a751cd2e 100644 --- a/test/integration/sds_static_integration_test.cc +++ b/test/integration/sds_static_integration_test.cc @@ -47,16 +47,20 @@ class SdsStaticDownstreamIntegrationTest ->mutable_common_tls_context(); common_tls_context->add_alpn_protocols("http/1.1"); - auto* validation_context = common_tls_context->mutable_validation_context(); + common_tls_context->mutable_validation_context_sds_secret_config()->set_name( + "validation_context"); + common_tls_context->add_tls_certificate_sds_secret_configs()->set_name("server_cert"); + + auto* secret = bootstrap.mutable_static_resources()->add_secrets(); + secret->set_name("validation_context"); + auto* validation_context = secret->mutable_validation_context(); validation_context->mutable_trusted_ca()->set_filename( TestEnvironment::runfilesPath("test/config/integration/certs/cacert.pem")); validation_context->add_verify_certificate_hash( "E0:F3:C8:CE:5E:2E:A3:05:F0:70:1F:F5:12:E3:6E:2E:" "97:92:82:84:A2:28:BC:F7:73:32:D3:39:30:A1:B6:FD"); - common_tls_context->add_tls_certificate_sds_secret_configs()->set_name("server_cert"); - - auto* secret = bootstrap.mutable_static_resources()->add_secrets(); + secret = bootstrap.mutable_static_resources()->add_secrets(); secret->set_name("server_cert"); auto* tls_certificate = secret->mutable_tls_certificate(); tls_certificate->mutable_certificate_chain()->set_filename(