diff --git a/docs/root/version_history/current.rst b/docs/root/version_history/current.rst index 81ef7c6c3eec2..00b03293101a5 100644 --- a/docs/root/version_history/current.rst +++ b/docs/root/version_history/current.rst @@ -40,6 +40,7 @@ Bug Fixes * access log: fix ``%UPSTREAM_CLUSTER%`` when used in http upstream access logs. Previously, it was always logging as an unset value. * access log: fix ``%UPSTREAM_CLUSTER%`` when used in http upstream access logs. Previously, it was always logging as an unset value. +* active health checks: health checks using a TLS transport socket and secrets delivered via :ref:`SDS ` will now wait until secrets are loaded before the first health check attempt. This should improve startup times by not having to wait for the :ref:`no_traffic_interval ` until the next attempt. * aws request signer: fix the AWS Request Signer extension to correctly normalize the path and query string to be signed according to AWS' guidelines, so that the hash on the server side matches. See `AWS SigV4 documentaion `_. * cluster: delete pools when they're idle to fix unbounded memory use when using PROXY protocol upstream with tcp_proxy. This behavior can be temporarily reverted by setting the ``envoy.reloadable_features.conn_pool_delete_when_idle`` runtime guard to false. * ext_authz: fixed skipping authentication when returning either a direct response or a redirect. This behavior can be temporarily reverted by setting the ``envoy.reloadable_features.http_ext_authz_do_not_skip_direct_response_and_redirect`` runtime guard to false. diff --git a/envoy/network/transport_socket.h b/envoy/network/transport_socket.h index f911c4d99dec0..312b569154443 100644 --- a/envoy/network/transport_socket.h +++ b/envoy/network/transport_socket.h @@ -249,6 +249,13 @@ class TransportSocketFactory { * negotiation. */ virtual bool supportsAlpn() const { return false; } + + /** + * @param a callback to be invoked when the secrets required by the created transport + * sockets are ready. Will be invoked immediately if no secrets are required or if they + * are already loaded. + */ + virtual void addReadyCb(std::function callback) PURE; }; using TransportSocketFactoryPtr = std::unique_ptr; diff --git a/envoy/upstream/upstream.h b/envoy/upstream/upstream.h index d8b73f39de23f..e5cb09ec173f7 100644 --- a/envoy/upstream/upstream.h +++ b/envoy/upstream/upstream.h @@ -108,6 +108,16 @@ class Host : virtual public HostDescription { Network::TransportSocketOptionsConstSharedPtr transport_socket_options, const envoy::config::core::v3::Metadata* metadata) const PURE; + /** + * Register a callback to be invoked when secrets are ready for the health + * checking transport socket that corresponds to the provided metadata. + * @param callback supplies the callback to be invoked. + * @param metadata supplies the metadata to be used for resolving transport socket matches. + */ + virtual void + addHealthCheckingReadyCb(std::function callback, + const envoy::config::core::v3::Metadata* metadata) const PURE; + /** * @return host specific gauges. */ diff --git a/source/common/network/raw_buffer_socket.h b/source/common/network/raw_buffer_socket.h index bad90ef9bbfa1..bf018f65810bd 100644 --- a/source/common/network/raw_buffer_socket.h +++ b/source/common/network/raw_buffer_socket.h @@ -35,6 +35,7 @@ class RawBufferSocketFactory : public TransportSocketFactory { createTransportSocket(TransportSocketOptionsConstSharedPtr options) const override; bool implementsSecureTransport() const override; bool usesProxyProtocolOptions() const override { return false; } + void addReadyCb(std::function callback) override { callback(); } }; } // namespace Network diff --git a/source/common/quic/quic_transport_socket_factory.h b/source/common/quic/quic_transport_socket_factory.h index 7c98e1ef29e69..380e4ec562736 100644 --- a/source/common/quic/quic_transport_socket_factory.h +++ b/source/common/quic/quic_transport_socket_factory.h @@ -51,6 +51,7 @@ class QuicTransportSocketFactoryBase : public Network::TransportSocketFactory, bool implementsSecureTransport() const override { return true; } bool usesProxyProtocolOptions() const override { return false; } bool supportsAlpn() const override { return true; } + void addReadyCb(std::function callback) override { callback(); } protected: virtual void onSecretUpdated() PURE; diff --git a/source/common/upstream/health_checker_base_impl.cc b/source/common/upstream/health_checker_base_impl.cc index b82058cf2aadd..e1993959e335f 100644 --- a/source/common/upstream/health_checker_base_impl.cc +++ b/source/common/upstream/health_checker_base_impl.cc @@ -411,6 +411,22 @@ void HealthCheckerImplBase::ActiveHealthCheckSession::onTimeoutBase() { handleFailure(envoy::data::core::v3::NETWORK_TIMEOUT); } +void HealthCheckerImplBase::ActiveHealthCheckSession::start() { + // Start health checks only after secrets are ready for the transport socket that health checks + // will be performed on. If health checks start immediately, they may fail with "network" errors + // due to TLS credentials not yet being loaded, which can result in long startup times. The + // callback needs to make sure this ActiveHealthCheckSession wasn't deleted before starting the + // health check loop in case it takes a while for the socket to become ready. + auto weak_self = weak_from_this(); + host_->addHealthCheckingReadyCb( + [weak_self] { + if (auto self = weak_self.lock()) { + self->onInitialInterval(); + } + }, + parent_.transportSocketMatchMetadata().get()); +} + void HealthCheckerImplBase::ActiveHealthCheckSession::onInitialInterval() { if (parent_.initial_jitter_.count() == 0) { onIntervalBase(); diff --git a/source/common/upstream/health_checker_base_impl.h b/source/common/upstream/health_checker_base_impl.h index 5081aac7351d2..322c982e16d03 100644 --- a/source/common/upstream/health_checker_base_impl.h +++ b/source/common/upstream/health_checker_base_impl.h @@ -73,12 +73,13 @@ class HealthCheckerImplBase : public HealthChecker, } protected: - class ActiveHealthCheckSession : public Event::DeferredDeletable { + class ActiveHealthCheckSession : public Event::DeferredDeletable, + public std::enable_shared_from_this { public: ~ActiveHealthCheckSession() override; HealthTransition setUnhealthy(envoy::data::core::v3::HealthCheckFailureType type); void onDeferredDeleteBase(); - void start() { onInitialInterval(); } + void start(); protected: ActiveHealthCheckSession(HealthCheckerImplBase& parent, HostSharedPtr host); diff --git a/source/common/upstream/upstream_impl.cc b/source/common/upstream/upstream_impl.cc index 44097ffe65de4..6c1d3232fbfad 100644 --- a/source/common/upstream/upstream_impl.cc +++ b/source/common/upstream/upstream_impl.cc @@ -335,6 +335,14 @@ Network::ClientConnectionPtr HostImpl::createConnection( return connection; } +void HostImpl::addHealthCheckingReadyCb(std::function callback, + const envoy::config::core::v3::Metadata* metadata) const { + Network::TransportSocketFactory& factory = + (metadata != nullptr) ? resolveTransportSocketFactory(healthCheckAddress(), metadata) + : transportSocketFactory(); + factory.addReadyCb(callback); +} + void HostImpl::weight(uint32_t new_weight) { weight_ = std::max(1U, new_weight); } std::vector HostsPerLocalityImpl::filter( diff --git a/source/common/upstream/upstream_impl.h b/source/common/upstream/upstream_impl.h index 6674082cd8d45..cbb3585f0dfb1 100644 --- a/source/common/upstream/upstream_impl.h +++ b/source/common/upstream/upstream_impl.h @@ -226,6 +226,9 @@ class HostImpl : public HostDescriptionImpl, Network::TransportSocketOptionsConstSharedPtr transport_socket_options, const envoy::config::core::v3::Metadata* metadata) const override; + void addHealthCheckingReadyCb(std::function callback, + const envoy::config::core::v3::Metadata* metadata) const override; + std::vector> gauges() const override { return stats().gauges(); diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h index a93c60f15a847..5e6b5797163c0 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.h +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -137,6 +137,9 @@ class TsiSocketFactory : public Network::TransportSocketFactory { Network::TransportSocketPtr createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; + // TODO(mpuncel) only invoke callback() once secrets are ready. + void addReadyCb(std::function callback) override { callback(); }; + private: HandshakerFactory handshaker_factory_; HandshakeValidator handshake_validator_; diff --git a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h index 521804758417c..cb1196e57f884 100644 --- a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h +++ b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h @@ -49,6 +49,7 @@ class UpstreamProxyProtocolSocketFactory : public Network::TransportSocketFactor createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; bool implementsSecureTransport() const override; bool usesProxyProtocolOptions() const override { return true; } + void addReadyCb(std::function callback) override { callback(); }; private: Network::TransportSocketFactoryPtr transport_socket_factory_; diff --git a/source/extensions/transport_sockets/starttls/starttls_socket.h b/source/extensions/transport_sockets/starttls/starttls_socket.h index 20f24da3a5b69..97d22d3c5a4a1 100644 --- a/source/extensions/transport_sockets/starttls/starttls_socket.h +++ b/source/extensions/transport_sockets/starttls/starttls_socket.h @@ -77,6 +77,7 @@ class StartTlsSocketFactory : public Network::TransportSocketFactory, createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; bool implementsSecureTransport() const override { return false; } bool usesProxyProtocolOptions() const override { return false; } + void addReadyCb(std::function callback) override { callback(); } private: Network::TransportSocketFactoryPtr raw_socket_factory_; diff --git a/source/extensions/transport_sockets/tap/tap.h b/source/extensions/transport_sockets/tap/tap.h index 6e15f159e33c5..47f506698d81f 100644 --- a/source/extensions/transport_sockets/tap/tap.h +++ b/source/extensions/transport_sockets/tap/tap.h @@ -43,6 +43,9 @@ class TapSocketFactory : public Network::TransportSocketFactory, bool implementsSecureTransport() const override; bool usesProxyProtocolOptions() const override; + // TODO(mpuncel) only invoke callback() once secrets are ready. + void addReadyCb(std::function callback) override { callback(); }; + private: Network::TransportSocketFactoryPtr transport_socket_factory_; }; diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index c2eef31132e3d..9675c40b350c4 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -379,13 +379,44 @@ bool ClientSslSocketFactory::implementsSecureTransport() const { return true; } void ClientSslSocketFactory::onAddOrUpdateSecret() { ENVOY_LOG(debug, "Secret is updated."); + bool should_run_callbacks = false; { absl::WriterMutexLock l(&ssl_ctx_mu_); ssl_ctx_ = manager_.createSslClientContext(stats_scope_, *config_, ssl_ctx_); + if (ssl_ctx_) { + should_run_callbacks = true; + } + } + if (should_run_callbacks) { + { + absl::WriterMutexLock m(&secrets_ready_callbacks_mu_); + for (const auto& cb : secrets_ready_callbacks_) { + cb(); + } + secrets_ready_callbacks_.clear(); + } } stats_.ssl_context_update_by_sds_.inc(); } +void ClientSslSocketFactory::addReadyCb(std::function callback) { + bool immediately_run_callback = false; + { + absl::ReaderMutexLock l(&ssl_ctx_mu_); + if (ssl_ctx_) { + immediately_run_callback = true; + } else { + { + absl::WriterMutexLock m(&secrets_ready_callbacks_mu_); + secrets_ready_callbacks_.push_back(callback); + } + } + } + if (immediately_run_callback) { + callback(); + } +} + ServerSslSocketFactory::ServerSslSocketFactory(Envoy::Ssl::ServerContextConfigPtr config, Envoy::Ssl::ContextManager& manager, Stats::Scope& stats_scope, @@ -425,13 +456,45 @@ bool ServerSslSocketFactory::implementsSecureTransport() const { return true; } void ServerSslSocketFactory::onAddOrUpdateSecret() { ENVOY_LOG(debug, "Secret is updated."); + bool should_run_callbacks = false; { absl::WriterMutexLock l(&ssl_ctx_mu_); ssl_ctx_ = manager_.createSslServerContext(stats_scope_, *config_, server_names_, ssl_ctx_); + + if (ssl_ctx_) { + should_run_callbacks = true; + } + } + if (should_run_callbacks) { + { + absl::WriterMutexLock l(&secrets_ready_callbacks_mu_); + for (const auto& cb : secrets_ready_callbacks_) { + cb(); + } + secrets_ready_callbacks_.clear(); + } } stats_.ssl_context_update_by_sds_.inc(); } +void ServerSslSocketFactory::addReadyCb(std::function callback) { + bool immediately_run_callback = false; + { + absl::ReaderMutexLock l(&ssl_ctx_mu_); + if (ssl_ctx_) { + immediately_run_callback = true; + } else { + { + absl::WriterMutexLock m(&secrets_ready_callbacks_mu_); + secrets_ready_callbacks_.push_back(callback); + } + } + } + if (immediately_run_callback) { + callback(); + } +} + } // namespace Tls } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index 186bebbabc067..34e18d346c0c8 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -112,6 +112,8 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory, bool usesProxyProtocolOptions() const override { return false; } bool supportsAlpn() const override { return true; } + void addReadyCb(std::function callback) override; + // Secret::SecretCallbacks void onAddOrUpdateSecret() override; @@ -126,6 +128,9 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory, Envoy::Ssl::ClientContextConfigPtr config_; mutable absl::Mutex ssl_ctx_mu_; Envoy::Ssl::ClientContextSharedPtr ssl_ctx_ ABSL_GUARDED_BY(ssl_ctx_mu_); + mutable absl::Mutex secrets_ready_callbacks_mu_; + std::list> + secrets_ready_callbacks_ ABSL_GUARDED_BY(secrets_ready_callbacks_mu_); }; class ServerSslSocketFactory : public Network::TransportSocketFactory, @@ -141,6 +146,8 @@ class ServerSslSocketFactory : public Network::TransportSocketFactory, bool implementsSecureTransport() const override; bool usesProxyProtocolOptions() const override { return false; } + void addReadyCb(std::function callback) override; + // Secret::SecretCallbacks void onAddOrUpdateSecret() override; @@ -152,6 +159,9 @@ class ServerSslSocketFactory : public Network::TransportSocketFactory, const std::vector server_names_; mutable absl::Mutex ssl_ctx_mu_; Envoy::Ssl::ServerContextSharedPtr ssl_ctx_ ABSL_GUARDED_BY(ssl_ctx_mu_); + mutable absl::Mutex secrets_ready_callbacks_mu_; + std::list> + secrets_ready_callbacks_ ABSL_GUARDED_BY(secrets_ready_callbacks_mu_); }; } // namespace Tls diff --git a/test/common/upstream/health_checker_impl_test.cc b/test/common/upstream/health_checker_impl_test.cc index 973eb95cacd0a..96e852a5c2c66 100644 --- a/test/common/upstream/health_checker_impl_test.cc +++ b/test/common/upstream/health_checker_impl_test.cc @@ -1104,6 +1104,8 @@ TEST_F(HttpHealthCheckerImplTest, TlsOptions) { Network::TransportSocketFactoryPtr(socket_factory)); cluster_->info_->transport_socket_matcher_.reset(transport_socket_match); + EXPECT_CALL(*socket_factory, addReadyCb(_)) + .WillOnce(Invoke([&](std::function callback) -> void { callback(); })); EXPECT_CALL(*socket_factory, createTransportSocket(ApplicationProtocolListEq("http1"))); allocHealthChecker(yaml); @@ -2582,13 +2584,19 @@ TEST_F(HttpHealthCheckerImplTest, TransportSocketMatchCriteria) { ALL_TRANSPORT_SOCKET_MATCH_STATS(POOL_COUNTER_PREFIX(stats_store, "test"))}; auto health_check_only_socket_factory = std::make_unique(); - // We expect resolve() to be called twice, once for endpoint socket matching (with no metadata in - // this test) and once for health check socket matching. In the latter we expect metadata that - // matches the above object. + // We expect resolve() to be called 3 times, once for endpoint socket matching (with no metadata + // in this test) and twice for health check socket matching (once for checking if secrets are + // ready on the transport socket, and again for actually getting the health check transport socket + // to create a connection). In the latter 2 calls, we expect metadata that matches the above + // object. EXPECT_CALL(*transport_socket_match, resolve(nullptr)); EXPECT_CALL(*transport_socket_match, resolve(MetadataEq(metadata))) - .WillOnce(Return(TransportSocketMatcher::MatchData( - *health_check_only_socket_factory, health_transport_socket_stats, "health_check_only"))); + .Times(2) + .WillRepeatedly(Return(TransportSocketMatcher::MatchData( + *health_check_only_socket_factory, health_transport_socket_stats, "health_check_only"))) + .RetiresOnSaturation(); + EXPECT_CALL(*health_check_only_socket_factory, addReadyCb(_)) + .WillOnce(Invoke([&](std::function callback) -> void { callback(); })); // The health_check_only_socket_factory should be used to create a transport socket for the health // check connection. EXPECT_CALL(*health_check_only_socket_factory, createTransportSocket(_)); @@ -2604,7 +2612,11 @@ TEST_F(HttpHealthCheckerImplTest, TransportSocketMatchCriteria) { expectStreamCreate(0); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, enableTimer(_, _)); health_checker_->start(); - EXPECT_EQ(health_transport_socket_stats.total_match_count_.value(), 1); + + // We expect 2 transport socket matches: one for when + // addHealthCheckingReadyCb() evaluates the match to register a callback on + // the socket, and once when the health checks are actually performed. + EXPECT_EQ(health_transport_socket_stats.total_match_count_.value(), 2); } TEST_F(HttpHealthCheckerImplTest, NoTransportSocketMatchCriteria) { @@ -2624,6 +2636,9 @@ TEST_F(HttpHealthCheckerImplTest, NoTransportSocketMatchCriteria) { )EOF"; auto default_socket_factory = std::make_unique(); + + EXPECT_CALL(*default_socket_factory, addReadyCb(_)) + .WillOnce(Invoke([&](std::function callback) -> void { callback(); })); // The default_socket_factory should be used to create a transport socket for the health check // connection. EXPECT_CALL(*default_socket_factory, createTransportSocket(_)); diff --git a/test/common/upstream/transport_socket_matcher_test.cc b/test/common/upstream/transport_socket_matcher_test.cc index 2bedc726d5e86..8ab3b6254229a 100644 --- a/test/common/upstream/transport_socket_matcher_test.cc +++ b/test/common/upstream/transport_socket_matcher_test.cc @@ -33,6 +33,7 @@ class FakeTransportSocketFactory : public Network::TransportSocketFactory { MOCK_METHOD(bool, usesProxyProtocolOptions, (), (const)); MOCK_METHOD(Network::TransportSocketPtr, createTransportSocket, (Network::TransportSocketOptionsConstSharedPtr), (const)); + MOCK_METHOD(void, addReadyCb, (std::function)); FakeTransportSocketFactory(std::string id) : id_(std::move(id)) {} std::string id() const { return id_; } @@ -49,6 +50,7 @@ class FooTransportSocketFactory MOCK_METHOD(bool, usesProxyProtocolOptions, (), (const)); MOCK_METHOD(Network::TransportSocketPtr, createTransportSocket, (Network::TransportSocketOptionsConstSharedPtr), (const)); + MOCK_METHOD(void, addReadyCb, (std::function)); Network::TransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& proto, diff --git a/test/extensions/transport_sockets/tls/ssl_socket_test.cc b/test/extensions/transport_sockets/tls/ssl_socket_test.cc index db3ccd7c25202..fa66195c95c2e 100644 --- a/test/extensions/transport_sockets/tls/ssl_socket_test.cc +++ b/test/extensions/transport_sockets/tls/ssl_socket_test.cc @@ -59,6 +59,7 @@ using testing::ContainsRegex; using testing::DoAll; using testing::InSequence; using testing::Invoke; +using testing::MockFunction; using testing::NiceMock; using testing::Return; using testing::ReturnRef; @@ -4701,6 +4702,12 @@ TEST_P(SslSocketTest, DownstreamNotReadySslSocket) { ContextManagerImpl manager(time_system_); ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, stats_store, std::vector{}); + + // Add a secrets ready callback that should not be invoked. + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()).Times(0); + server_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + auto transport_socket = server_ssl_socket_factory.createTransportSocket(nullptr); EXPECT_EQ(EMPTY_STRING, transport_socket->protocol()); EXPECT_EQ(nullptr, transport_socket->ssl()); @@ -4737,6 +4744,12 @@ TEST_P(SslSocketTest, UpstreamNotReadySslSocket) { ContextManagerImpl manager(time_system_); ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager, stats_store); + + // Add a secrets ready callback that should not be invoked. + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()).Times(0); + client_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + auto transport_socket = client_ssl_socket_factory.createTransportSocket(nullptr); EXPECT_EQ(EMPTY_STRING, transport_socket->protocol()); EXPECT_EQ(nullptr, transport_socket->ssl()); @@ -4749,6 +4762,183 @@ TEST_P(SslSocketTest, UpstreamNotReadySslSocket) { EXPECT_EQ("TLS error: Secret is not supplied by SDS", transport_socket->failureReason()); } +// Validate that secrets callbacks are invoked when secrets become ready. +TEST_P(SslSocketTest, ClientAddSecretsReadyCallback) { + Stats::TestUtil::TestStore stats_store; + NiceMock local_info; + testing::NiceMock factory_context; + NiceMock init_manager; + NiceMock dispatcher; + EXPECT_CALL(factory_context, localInfo()).WillOnce(ReturnRef(local_info)); + EXPECT_CALL(factory_context, stats()).WillOnce(ReturnRef(stats_store)); + EXPECT_CALL(factory_context, initManager()).WillRepeatedly(ReturnRef(init_manager)); + EXPECT_CALL(factory_context, dispatcher()).WillRepeatedly(ReturnRef(dispatcher)); + + envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context; + auto sds_secret_configs = + tls_context.mutable_common_tls_context()->mutable_tls_certificate_sds_secret_configs()->Add(); + sds_secret_configs->set_name("abc.com"); + sds_secret_configs->mutable_sds_config(); + auto client_cfg = std::make_unique(tls_context, factory_context); + EXPECT_TRUE(client_cfg->tlsCertificates().empty()); + EXPECT_FALSE(client_cfg->isReady()); + + NiceMock context_manager; + ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), context_manager, + stats_store); + + // Add a secrets ready callback. It should not be invoked until onAddOrUpdateSecret() is called. + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()).Times(0); + client_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + + // Call onAddOrUpdateSecret, but return a null ssl_ctx. This should not invoke the callback. + EXPECT_CALL(context_manager, createSslClientContext(_, _, _)).WillOnce(Return(nullptr)); + client_ssl_socket_factory.onAddOrUpdateSecret(); + + EXPECT_CALL(mock_callback_, Call()); + Ssl::ClientContextSharedPtr mock_context = std::make_shared(); + EXPECT_CALL(context_manager, createSslClientContext(_, _, _)).WillOnce(Return(mock_context)); + client_ssl_socket_factory.onAddOrUpdateSecret(); + + // Add another callback, it should be invoked immediately. + MockFunction second_callback_; + EXPECT_CALL(second_callback_, Call()); + client_ssl_socket_factory.addReadyCb(second_callback_.AsStdFunction()); +} + +// Validate that secrets callbacks are invoked when secrets become ready. +TEST_P(SslSocketTest, ServerAddSecretsReadyCallback) { + Stats::TestUtil::TestStore stats_store; + NiceMock local_info; + testing::NiceMock factory_context; + NiceMock init_manager; + NiceMock dispatcher; + EXPECT_CALL(factory_context, localInfo()).WillOnce(ReturnRef(local_info)); + EXPECT_CALL(factory_context, stats()).WillOnce(ReturnRef(stats_store)); + EXPECT_CALL(factory_context, initManager()).WillRepeatedly(ReturnRef(init_manager)); + EXPECT_CALL(factory_context, dispatcher()).WillRepeatedly(ReturnRef(dispatcher)); + + envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; + auto sds_secret_configs = + tls_context.mutable_common_tls_context()->mutable_tls_certificate_sds_secret_configs()->Add(); + sds_secret_configs->set_name("abc.com"); + sds_secret_configs->mutable_sds_config(); + auto server_cfg = std::make_unique(tls_context, factory_context); + EXPECT_TRUE(server_cfg->tlsCertificates().empty()); + EXPECT_FALSE(server_cfg->isReady()); + + NiceMock context_manager; + ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), context_manager, + stats_store, std::vector{}); + + // Add a secrets ready callback. It should not be invoked until onAddOrUpdateSecret() is called. + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()).Times(0); + server_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + + // Call onAddOrUpdateSecret, but return a null ssl_ctx. This should not invoke the callback. + EXPECT_CALL(context_manager, createSslServerContext(_, _, _, _)).WillOnce(Return(nullptr)); + server_ssl_socket_factory.onAddOrUpdateSecret(); + + // Now return a ssl context which should result in the callback being invoked. + EXPECT_CALL(mock_callback_, Call()); + Ssl::ServerContextSharedPtr mock_context = std::make_shared(); + EXPECT_CALL(context_manager, createSslServerContext(_, _, _, _)).WillOnce(Return(mock_context)); + server_ssl_socket_factory.onAddOrUpdateSecret(); + + // Add another callback, it should be invoked immediately. + MockFunction second_callback_; + EXPECT_CALL(second_callback_, Call()); + server_ssl_socket_factory.addReadyCb(second_callback_.AsStdFunction()); +} + +// Tests adding a callback and adding secrets in parallel. This is intended to +// catch a race condition introduced in +// https://github.com/envoyproxy/envoy/pull/13516 where a callback would never +// be called due to a concurrency bug. +TEST_P(SslSocketTest, ServerAddSecretsReadyCallbackParallel) { + Stats::TestUtil::TestStore stats_store; + NiceMock local_info; + testing::NiceMock factory_context; + NiceMock init_manager; + NiceMock dispatcher; + EXPECT_CALL(factory_context, localInfo()).WillOnce(ReturnRef(local_info)); + EXPECT_CALL(factory_context, stats()).WillOnce(ReturnRef(stats_store)); + EXPECT_CALL(factory_context, initManager()).WillRepeatedly(ReturnRef(init_manager)); + EXPECT_CALL(factory_context, dispatcher()).WillRepeatedly(ReturnRef(dispatcher)); + + envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; + auto sds_secret_configs = + tls_context.mutable_common_tls_context()->mutable_tls_certificate_sds_secret_configs()->Add(); + sds_secret_configs->set_name("abc.com"); + sds_secret_configs->mutable_sds_config(); + auto server_cfg = std::make_unique(tls_context, factory_context); + EXPECT_TRUE(server_cfg->tlsCertificates().empty()); + EXPECT_FALSE(server_cfg->isReady()); + + NiceMock context_manager; + ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), context_manager, + stats_store, std::vector{}); + + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()); + Ssl::ServerContextSharedPtr mock_context = std::make_shared(); + EXPECT_CALL(context_manager, createSslServerContext(_, _, _, _)).WillOnce(Return(mock_context)); + + // Add a callback in a thread to potentially tickle the concurrency bug. + std::thread t([&server_ssl_socket_factory, &mock_callback_]() { + server_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + }); + + server_ssl_socket_factory.onAddOrUpdateSecret(); + + t.join(); +} + +// Tests adding a callback and adding secrets in parallel. This is intended to +// catch a race condition introduced in +// https://github.com/envoyproxy/envoy/pull/13516 where a callback would never +// be called due to a concurrency bug. +TEST_P(SslSocketTest, ClientAddSecretsReadyCallbackParallel) { + Stats::TestUtil::TestStore stats_store; + NiceMock local_info; + testing::NiceMock factory_context; + NiceMock init_manager; + NiceMock dispatcher; + EXPECT_CALL(factory_context, localInfo()).WillOnce(ReturnRef(local_info)); + EXPECT_CALL(factory_context, stats()).WillOnce(ReturnRef(stats_store)); + EXPECT_CALL(factory_context, initManager()).WillRepeatedly(ReturnRef(init_manager)); + EXPECT_CALL(factory_context, dispatcher()).WillRepeatedly(ReturnRef(dispatcher)); + + envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context; + auto sds_secret_configs = + tls_context.mutable_common_tls_context()->mutable_tls_certificate_sds_secret_configs()->Add(); + sds_secret_configs->set_name("abc.com"); + sds_secret_configs->mutable_sds_config(); + auto client_cfg = std::make_unique(tls_context, factory_context); + EXPECT_TRUE(client_cfg->tlsCertificates().empty()); + EXPECT_FALSE(client_cfg->isReady()); + + NiceMock context_manager; + ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), context_manager, + stats_store); + + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()); + Ssl::ClientContextSharedPtr mock_context = std::make_shared(); + EXPECT_CALL(context_manager, createSslClientContext(_, _, _)).WillOnce(Return(mock_context)); + + // Add a callback in a thread to potentially tickle the concurrency bug. + std::thread t([&client_ssl_socket_factory, &mock_callback_]() { + client_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + }); + + client_ssl_socket_factory.onAddOrUpdateSecret(); + + t.join(); +} + TEST_P(SslSocketTest, TestTransportSocketCallback) { // Make MockTransportSocketCallbacks. Network::MockIoHandle io_handle; diff --git a/test/mocks/network/transport_socket.h b/test/mocks/network/transport_socket.h index ebb9e26c28904..9fb523322643a 100644 --- a/test/mocks/network/transport_socket.h +++ b/test/mocks/network/transport_socket.h @@ -41,6 +41,7 @@ class MockTransportSocketFactory : public TransportSocketFactory { MOCK_METHOD(bool, supportsAlpn, (), (const)); MOCK_METHOD(TransportSocketPtr, createTransportSocket, (TransportSocketOptionsConstSharedPtr), (const)); + MOCK_METHOD(void, addReadyCb, (std::function)); }; } // namespace Network diff --git a/test/mocks/ssl/mocks.cc b/test/mocks/ssl/mocks.cc index 6e3ea24b32598..24e8a9f263b88 100644 --- a/test/mocks/ssl/mocks.cc +++ b/test/mocks/ssl/mocks.cc @@ -22,6 +22,9 @@ MockClientContextConfig::MockClientContextConfig() { } MockClientContextConfig::~MockClientContextConfig() = default; +MockServerContext::MockServerContext() = default; +MockServerContext::~MockServerContext() = default; + MockServerContextConfig::MockServerContextConfig() = default; MockServerContextConfig::~MockServerContextConfig() = default; diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 43ad275fce277..dbfe59fae27b9 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -107,6 +107,17 @@ class MockClientContextConfig : public ClientContextConfig { std::string test_{}; }; +class MockServerContext : public ServerContext { +public: + MockServerContext(); + ~MockServerContext() override; + + MOCK_METHOD(size_t, daysUntilFirstCertExpires, (), (const)); + MOCK_METHOD(absl::optional, secondsUntilFirstOcspResponseExpires, (), (const)); + MOCK_METHOD(CertificateDetailsPtr, getCaCertInformation, (), (const)); + MOCK_METHOD(std::vector, getCertChainInformation, (), (const)); +}; + class MockServerContextConfig : public ServerContextConfig { public: MockServerContextConfig(); diff --git a/test/mocks/upstream/host.h b/test/mocks/upstream/host.h index cd6cc8e21bc44..07a0db33623b9 100644 --- a/test/mocks/upstream/host.h +++ b/test/mocks/upstream/host.h @@ -198,6 +198,8 @@ class MockHost : public Host { MOCK_METHOD(void, priority, (uint32_t)); MOCK_METHOD(bool, warmed, (), (const)); MOCK_METHOD(MonotonicTime, creationTime, (), (const)); + MOCK_METHOD(void, addHealthCheckingReadyCb, + (std::function, const envoy::config::core::v3::Metadata*), (const)); testing::NiceMock cluster_; Network::TransportSocketFactoryPtr socket_factory_;