diff --git a/docs/root/version_history/current.rst b/docs/root/version_history/current.rst index a92d88f55707d..878e4c7f63125 100644 --- a/docs/root/version_history/current.rst +++ b/docs/root/version_history/current.rst @@ -58,6 +58,7 @@ Bug Fixes * access log: fix ``%UPSTREAM_CLUSTER%`` when used in http upstream access logs. Previously, it was always logging as an unset value. * access log: fix ``%UPSTREAM_CLUSTER%`` when used in http upstream access logs. Previously, it was always logging as an unset value. +* active health checks: health checks using a TLS transport socket and secrets delivered via :ref:`SDS ` will now wait until secrets are loaded before the first health check attempt. This should improve startup times by not having to wait for the :ref:`no_traffic_interval ` until the next attempt. * aws request signer: fix the AWS Request Signer extension to correctly normalize the path and query string to be signed according to AWS' guidelines, so that the hash on the server side matches. See `AWS SigV4 documentaion `_. * cluster: delete pools when they're idle to fix unbounded memory use when using PROXY protocol upstream with tcp_proxy. This behavior can be temporarily reverted by setting the ``envoy.reloadable_features.conn_pool_delete_when_idle`` runtime guard to false. * ext_authz: fix the ext_authz filter to correctly merge multiple same headers using the ',' as separator in the check request to the external authorization service. diff --git a/envoy/network/transport_socket.h b/envoy/network/transport_socket.h index f911c4d99dec0..8b6df61bf056a 100644 --- a/envoy/network/transport_socket.h +++ b/envoy/network/transport_socket.h @@ -249,6 +249,13 @@ class TransportSocketFactory { * negotiation. */ virtual bool supportsAlpn() const { return false; } + + /** + * @param callback supplies a callback to be invoked when the secrets required by the created + * transport sockets are ready. Will be invoked immediately if no secrets are required or if they + * are already loaded. + */ + virtual void addReadyCb(std::function callback) PURE; }; using TransportSocketFactoryPtr = std::unique_ptr; diff --git a/envoy/upstream/upstream.h b/envoy/upstream/upstream.h index 30bf5d8bb211c..e8fca7824ca4e 100644 --- a/envoy/upstream/upstream.h +++ b/envoy/upstream/upstream.h @@ -108,6 +108,16 @@ class Host : virtual public HostDescription { Network::TransportSocketOptionsConstSharedPtr transport_socket_options, const envoy::config::core::v3::Metadata* metadata) const PURE; + /** + * Register a callback to be invoked when secrets are ready for the health + * checking transport socket that corresponds to the provided metadata. + * @param callback supplies the callback to be invoked. + * @param metadata supplies the metadata to be used for resolving transport socket matches. + */ + virtual void + addHealthCheckingReadyCb(std::function callback, + const envoy::config::core::v3::Metadata* metadata) const PURE; + /** * @return host specific gauges. */ diff --git a/source/common/network/raw_buffer_socket.h b/source/common/network/raw_buffer_socket.h index bad90ef9bbfa1..bf018f65810bd 100644 --- a/source/common/network/raw_buffer_socket.h +++ b/source/common/network/raw_buffer_socket.h @@ -35,6 +35,7 @@ class RawBufferSocketFactory : public TransportSocketFactory { createTransportSocket(TransportSocketOptionsConstSharedPtr options) const override; bool implementsSecureTransport() const override; bool usesProxyProtocolOptions() const override { return false; } + void addReadyCb(std::function callback) override { callback(); } }; } // namespace Network diff --git a/source/common/quic/quic_transport_socket_factory.h b/source/common/quic/quic_transport_socket_factory.h index 7c98e1ef29e69..380e4ec562736 100644 --- a/source/common/quic/quic_transport_socket_factory.h +++ b/source/common/quic/quic_transport_socket_factory.h @@ -51,6 +51,7 @@ class QuicTransportSocketFactoryBase : public Network::TransportSocketFactory, bool implementsSecureTransport() const override { return true; } bool usesProxyProtocolOptions() const override { return false; } bool supportsAlpn() const override { return true; } + void addReadyCb(std::function callback) override { callback(); } protected: virtual void onSecretUpdated() PURE; diff --git a/source/common/upstream/health_checker_base_impl.cc b/source/common/upstream/health_checker_base_impl.cc index b82058cf2aadd..ce1be07f07a36 100644 --- a/source/common/upstream/health_checker_base_impl.cc +++ b/source/common/upstream/health_checker_base_impl.cc @@ -229,7 +229,8 @@ HealthCheckerImplBase::ActiveHealthCheckSession::ActiveHealthCheckSession( HealthCheckerImplBase& parent, HostSharedPtr host) : host_(host), parent_(parent), interval_timer_(parent.dispatcher_.createTimer([this]() -> void { onIntervalBase(); })), - timeout_timer_(parent.dispatcher_.createTimer([this]() -> void { onTimeoutBase(); })) { + timeout_timer_(parent.dispatcher_.createTimer([this]() -> void { onTimeoutBase(); })), + lifetime_guard_(std::make_shared(1)) { if (!host->healthFlagGet(Host::HealthFlag::FAILED_ACTIVE_HC)) { parent.incHealthy(); @@ -411,6 +412,22 @@ void HealthCheckerImplBase::ActiveHealthCheckSession::onTimeoutBase() { handleFailure(envoy::data::core::v3::NETWORK_TIMEOUT); } +void HealthCheckerImplBase::ActiveHealthCheckSession::start() { + // Start health checks only after secrets are ready for the transport socket that health checks + // will be performed on. If health checks start immediately, they may fail with "network" errors + // due to TLS credentials not yet being loaded, which can result in long startup times. The + // callback needs to make sure this ActiveHealthCheckSession wasn't deleted before starting the + // health check loop in case it takes a while for the socket to become ready. + std::weak_ptr lifetime_guard = lifetime_guard_; + host_->addHealthCheckingReadyCb( + [this, lifetime_guard] { + if (lifetime_guard.lock()) { + onInitialInterval(); + } + }, + parent_.transportSocketMatchMetadata().get()); +} + void HealthCheckerImplBase::ActiveHealthCheckSession::onInitialInterval() { if (parent_.initial_jitter_.count() == 0) { onIntervalBase(); diff --git a/source/common/upstream/health_checker_base_impl.h b/source/common/upstream/health_checker_base_impl.h index 5081aac7351d2..fdbe074b97b65 100644 --- a/source/common/upstream/health_checker_base_impl.h +++ b/source/common/upstream/health_checker_base_impl.h @@ -78,7 +78,7 @@ class HealthCheckerImplBase : public HealthChecker, ~ActiveHealthCheckSession() override; HealthTransition setUnhealthy(envoy::data::core::v3::HealthCheckFailureType type); void onDeferredDeleteBase(); - void start() { onInitialInterval(); } + void start(); protected: ActiveHealthCheckSession(HealthCheckerImplBase& parent, HostSharedPtr host); @@ -107,6 +107,10 @@ class HealthCheckerImplBase : public HealthChecker, uint32_t num_unhealthy_{}; uint32_t num_healthy_{}; bool first_check_{true}; + + // lifetime_guard_ is used to ensure health checks are not started via a callback after this + // ActiveHealthCheckSession has been deleted. + std::shared_ptr lifetime_guard_{}; }; using ActiveHealthCheckSessionPtr = std::unique_ptr; diff --git a/source/common/upstream/upstream_impl.cc b/source/common/upstream/upstream_impl.cc index eca68f00929f0..db8edc0b7ce7a 100644 --- a/source/common/upstream/upstream_impl.cc +++ b/source/common/upstream/upstream_impl.cc @@ -335,6 +335,14 @@ Network::ClientConnectionPtr HostImpl::createConnection( return connection; } +void HostImpl::addHealthCheckingReadyCb(std::function callback, + const envoy::config::core::v3::Metadata* metadata) const { + Network::TransportSocketFactory& factory = + (metadata != nullptr) ? resolveTransportSocketFactory(healthCheckAddress(), metadata) + : transportSocketFactory(); + factory.addReadyCb(callback); +} + void HostImpl::weight(uint32_t new_weight) { weight_ = std::max(1U, new_weight); } std::vector HostsPerLocalityImpl::filter( diff --git a/source/common/upstream/upstream_impl.h b/source/common/upstream/upstream_impl.h index 91b9c2133a887..93188ef3030cb 100644 --- a/source/common/upstream/upstream_impl.h +++ b/source/common/upstream/upstream_impl.h @@ -226,6 +226,9 @@ class HostImpl : public HostDescriptionImpl, Network::TransportSocketOptionsConstSharedPtr transport_socket_options, const envoy::config::core::v3::Metadata* metadata) const override; + void addHealthCheckingReadyCb(std::function callback, + const envoy::config::core::v3::Metadata* metadata) const override; + std::vector> gauges() const override { return stats().gauges(); diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h index a93c60f15a847..980a5cf6eeb1f 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.h +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -137,6 +137,8 @@ class TsiSocketFactory : public Network::TransportSocketFactory { Network::TransportSocketPtr createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; + void addReadyCb(std::function callback) override { callback(); }; + private: HandshakerFactory handshaker_factory_; HandshakeValidator handshake_validator_; diff --git a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h index 521804758417c..cb1196e57f884 100644 --- a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h +++ b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h @@ -49,6 +49,7 @@ class UpstreamProxyProtocolSocketFactory : public Network::TransportSocketFactor createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; bool implementsSecureTransport() const override; bool usesProxyProtocolOptions() const override { return true; } + void addReadyCb(std::function callback) override { callback(); }; private: Network::TransportSocketFactoryPtr transport_socket_factory_; diff --git a/source/extensions/transport_sockets/starttls/starttls_socket.h b/source/extensions/transport_sockets/starttls/starttls_socket.h index 20f24da3a5b69..c5860565aa7cc 100644 --- a/source/extensions/transport_sockets/starttls/starttls_socket.h +++ b/source/extensions/transport_sockets/starttls/starttls_socket.h @@ -78,6 +78,9 @@ class StartTlsSocketFactory : public Network::TransportSocketFactory, bool implementsSecureTransport() const override { return false; } bool usesProxyProtocolOptions() const override { return false; } + // TODO(mpuncel) only invoke callback() once secrets are ready. + void addReadyCb(std::function callback) override { callback(); } + private: Network::TransportSocketFactoryPtr raw_socket_factory_; Network::TransportSocketFactoryPtr tls_socket_factory_; diff --git a/source/extensions/transport_sockets/tap/tap.h b/source/extensions/transport_sockets/tap/tap.h index 6e15f159e33c5..47f506698d81f 100644 --- a/source/extensions/transport_sockets/tap/tap.h +++ b/source/extensions/transport_sockets/tap/tap.h @@ -43,6 +43,9 @@ class TapSocketFactory : public Network::TransportSocketFactory, bool implementsSecureTransport() const override; bool usesProxyProtocolOptions() const override; + // TODO(mpuncel) only invoke callback() once secrets are ready. + void addReadyCb(std::function callback) override { callback(); }; + private: Network::TransportSocketFactoryPtr transport_socket_factory_; }; diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index c2eef31132e3d..de32eed7ec578 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -379,13 +379,41 @@ bool ClientSslSocketFactory::implementsSecureTransport() const { return true; } void ClientSslSocketFactory::onAddOrUpdateSecret() { ENVOY_LOG(debug, "Secret is updated."); + bool should_run_callbacks = false; { absl::WriterMutexLock l(&ssl_ctx_mu_); ssl_ctx_ = manager_.createSslClientContext(stats_scope_, *config_, ssl_ctx_); + if (ssl_ctx_) { + should_run_callbacks = true; + } + } + + if (should_run_callbacks) { + absl::WriterMutexLock m(&secrets_ready_callbacks_mu_); + for (const auto& cb : secrets_ready_callbacks_) { + cb(); + } + secrets_ready_callbacks_.clear(); } stats_.ssl_context_update_by_sds_.inc(); } +void ClientSslSocketFactory::addReadyCb(std::function callback) { + bool immediately_run_callback = false; + { + absl::ReaderMutexLock l(&ssl_ctx_mu_); + if (ssl_ctx_) { + immediately_run_callback = true; + } else { + absl::WriterMutexLock m(&secrets_ready_callbacks_mu_); + secrets_ready_callbacks_.push_back(callback); + } + } + if (immediately_run_callback) { + callback(); + } +} + ServerSslSocketFactory::ServerSslSocketFactory(Envoy::Ssl::ServerContextConfigPtr config, Envoy::Ssl::ContextManager& manager, Stats::Scope& stats_scope, @@ -425,13 +453,42 @@ bool ServerSslSocketFactory::implementsSecureTransport() const { return true; } void ServerSslSocketFactory::onAddOrUpdateSecret() { ENVOY_LOG(debug, "Secret is updated."); + bool should_run_callbacks = false; { absl::WriterMutexLock l(&ssl_ctx_mu_); ssl_ctx_ = manager_.createSslServerContext(stats_scope_, *config_, server_names_, ssl_ctx_); + + if (ssl_ctx_) { + should_run_callbacks = true; + } + } + + if (should_run_callbacks) { + absl::WriterMutexLock l(&secrets_ready_callbacks_mu_); + for (const auto& cb : secrets_ready_callbacks_) { + cb(); + } + secrets_ready_callbacks_.clear(); } stats_.ssl_context_update_by_sds_.inc(); } +void ServerSslSocketFactory::addReadyCb(std::function callback) { + bool immediately_run_callback = false; + { + absl::ReaderMutexLock l(&ssl_ctx_mu_); + if (ssl_ctx_) { + immediately_run_callback = true; + } else { + absl::WriterMutexLock m(&secrets_ready_callbacks_mu_); + secrets_ready_callbacks_.push_back(callback); + } + } + if (immediately_run_callback) { + callback(); + } +} + } // namespace Tls } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index 186bebbabc067..34e18d346c0c8 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -112,6 +112,8 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory, bool usesProxyProtocolOptions() const override { return false; } bool supportsAlpn() const override { return true; } + void addReadyCb(std::function callback) override; + // Secret::SecretCallbacks void onAddOrUpdateSecret() override; @@ -126,6 +128,9 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory, Envoy::Ssl::ClientContextConfigPtr config_; mutable absl::Mutex ssl_ctx_mu_; Envoy::Ssl::ClientContextSharedPtr ssl_ctx_ ABSL_GUARDED_BY(ssl_ctx_mu_); + mutable absl::Mutex secrets_ready_callbacks_mu_; + std::list> + secrets_ready_callbacks_ ABSL_GUARDED_BY(secrets_ready_callbacks_mu_); }; class ServerSslSocketFactory : public Network::TransportSocketFactory, @@ -141,6 +146,8 @@ class ServerSslSocketFactory : public Network::TransportSocketFactory, bool implementsSecureTransport() const override; bool usesProxyProtocolOptions() const override { return false; } + void addReadyCb(std::function callback) override; + // Secret::SecretCallbacks void onAddOrUpdateSecret() override; @@ -152,6 +159,9 @@ class ServerSslSocketFactory : public Network::TransportSocketFactory, const std::vector server_names_; mutable absl::Mutex ssl_ctx_mu_; Envoy::Ssl::ServerContextSharedPtr ssl_ctx_ ABSL_GUARDED_BY(ssl_ctx_mu_); + mutable absl::Mutex secrets_ready_callbacks_mu_; + std::list> + secrets_ready_callbacks_ ABSL_GUARDED_BY(secrets_ready_callbacks_mu_); }; } // namespace Tls diff --git a/test/common/upstream/health_checker_impl_test.cc b/test/common/upstream/health_checker_impl_test.cc index e41010cb3172e..e1cac269c9f81 100644 --- a/test/common/upstream/health_checker_impl_test.cc +++ b/test/common/upstream/health_checker_impl_test.cc @@ -1104,6 +1104,8 @@ TEST_F(HttpHealthCheckerImplTest, TlsOptions) { Network::TransportSocketFactoryPtr(socket_factory)); cluster_->info_->transport_socket_matcher_.reset(transport_socket_match); + EXPECT_CALL(*socket_factory, addReadyCb(_)) + .WillOnce(Invoke([&](std::function callback) -> void { callback(); })); EXPECT_CALL(*socket_factory, createTransportSocket(ApplicationProtocolListEq("http1"))); allocHealthChecker(yaml); @@ -2582,13 +2584,19 @@ TEST_F(HttpHealthCheckerImplTest, TransportSocketMatchCriteria) { ALL_TRANSPORT_SOCKET_MATCH_STATS(POOL_COUNTER_PREFIX(stats_store, "test"))}; auto health_check_only_socket_factory = std::make_unique(); - // We expect resolve() to be called twice, once for endpoint socket matching (with no metadata in - // this test) and once for health check socket matching. In the latter we expect metadata that - // matches the above object. + // We expect resolve() to be called 3 times, once for endpoint socket matching (with no metadata + // in this test) and twice for health check socket matching (once for checking if secrets are + // ready on the transport socket, and again for actually getting the health check transport socket + // to create a connection). In the latter 2 calls, we expect metadata that matches the above + // object. EXPECT_CALL(*transport_socket_match, resolve(nullptr)); EXPECT_CALL(*transport_socket_match, resolve(MetadataEq(metadata))) - .WillOnce(Return(TransportSocketMatcher::MatchData( - *health_check_only_socket_factory, health_transport_socket_stats, "health_check_only"))); + .Times(2) + .WillRepeatedly(Return(TransportSocketMatcher::MatchData( + *health_check_only_socket_factory, health_transport_socket_stats, "health_check_only"))) + .RetiresOnSaturation(); + EXPECT_CALL(*health_check_only_socket_factory, addReadyCb(_)) + .WillOnce(Invoke([&](std::function callback) -> void { callback(); })); // The health_check_only_socket_factory should be used to create a transport socket for the health // check connection. EXPECT_CALL(*health_check_only_socket_factory, createTransportSocket(_)); @@ -2604,7 +2612,11 @@ TEST_F(HttpHealthCheckerImplTest, TransportSocketMatchCriteria) { expectStreamCreate(0); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, enableTimer(_, _)); health_checker_->start(); - EXPECT_EQ(health_transport_socket_stats.total_match_count_.value(), 1); + + // We expect 2 transport socket matches: one for when + // addHealthCheckingReadyCb() evaluates the match to register a callback on + // the socket, and once when the health checks are actually performed. + EXPECT_EQ(health_transport_socket_stats.total_match_count_.value(), 2); } TEST_F(HttpHealthCheckerImplTest, NoTransportSocketMatchCriteria) { @@ -2624,6 +2636,9 @@ TEST_F(HttpHealthCheckerImplTest, NoTransportSocketMatchCriteria) { )EOF"; auto default_socket_factory = std::make_unique(); + + EXPECT_CALL(*default_socket_factory, addReadyCb(_)) + .WillOnce(Invoke([&](std::function callback) -> void { callback(); })); // The default_socket_factory should be used to create a transport socket for the health check // connection. EXPECT_CALL(*default_socket_factory, createTransportSocket(_)); @@ -3116,6 +3131,36 @@ TEST_F(HttpHealthCheckerImplTest, ServiceNameMismatch) { cluster_->prioritySet().getMockHostSet(0)->hosts_[0]->health()); } +// Test that the underlying transport socket becoming ready after the health check session +// is destroyed doesn't attempt to start health checks. +TEST_F(HttpHealthCheckerImplTest, NoHealthCheckAfterSessionDestroyed) { + auto default_socket_factory = std::make_unique(); + std::function callback = nullptr; + + // Capture the callback to addReadyCb. We will invoke it later once the session is destroyed. + EXPECT_CALL(*default_socket_factory, addReadyCb(_)) + .WillOnce(Invoke([&](std::function cb) -> void { callback = cb; })); + + auto transport_socket_match = + std::make_unique(std::move(default_socket_factory)); + EXPECT_CALL(*transport_socket_match, resolve(nullptr)); + cluster_->info_->transport_socket_matcher_ = std::move(transport_socket_match); + + setupNoServiceValidationHC(); + + cluster_->prioritySet().getMockHostSet(0)->hosts_ = { + makeTestHost(cluster_->info_, "tcp://127.0.0.1:80", simTime())}; + + health_checker_->start(); + + // Destroy the health checker object. + health_checker_.reset(); + + // Call the callback that would have started health checks had the health checker object not + // destroyed. This should not segfault or otherwise attempt to start health checking. + callback(); +} + TEST_F(ProdHttpHealthCheckerTest, ProdHttpHealthCheckerH2HealthChecking) { setupNoServiceValidationHCWithHttp2(); EXPECT_EQ(Http::CodecType::HTTP2, diff --git a/test/common/upstream/transport_socket_matcher_test.cc b/test/common/upstream/transport_socket_matcher_test.cc index 2bedc726d5e86..8ab3b6254229a 100644 --- a/test/common/upstream/transport_socket_matcher_test.cc +++ b/test/common/upstream/transport_socket_matcher_test.cc @@ -33,6 +33,7 @@ class FakeTransportSocketFactory : public Network::TransportSocketFactory { MOCK_METHOD(bool, usesProxyProtocolOptions, (), (const)); MOCK_METHOD(Network::TransportSocketPtr, createTransportSocket, (Network::TransportSocketOptionsConstSharedPtr), (const)); + MOCK_METHOD(void, addReadyCb, (std::function)); FakeTransportSocketFactory(std::string id) : id_(std::move(id)) {} std::string id() const { return id_; } @@ -49,6 +50,7 @@ class FooTransportSocketFactory MOCK_METHOD(bool, usesProxyProtocolOptions, (), (const)); MOCK_METHOD(Network::TransportSocketPtr, createTransportSocket, (Network::TransportSocketOptionsConstSharedPtr), (const)); + MOCK_METHOD(void, addReadyCb, (std::function)); Network::TransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& proto, diff --git a/test/extensions/transport_sockets/tls/ssl_socket_test.cc b/test/extensions/transport_sockets/tls/ssl_socket_test.cc index 88da907bd834d..7f252059f9165 100644 --- a/test/extensions/transport_sockets/tls/ssl_socket_test.cc +++ b/test/extensions/transport_sockets/tls/ssl_socket_test.cc @@ -59,6 +59,7 @@ using testing::ContainsRegex; using testing::DoAll; using testing::InSequence; using testing::Invoke; +using testing::MockFunction; using testing::NiceMock; using testing::Return; using testing::ReturnRef; @@ -4721,6 +4722,12 @@ TEST_P(SslSocketTest, DownstreamNotReadySslSocket) { ContextManagerImpl manager(time_system_); ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, stats_store, std::vector{}); + + // Add a secrets ready callback that should not be invoked. + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()).Times(0); + server_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + auto transport_socket = server_ssl_socket_factory.createTransportSocket(nullptr); EXPECT_EQ(EMPTY_STRING, transport_socket->protocol()); EXPECT_EQ(nullptr, transport_socket->ssl()); @@ -4757,6 +4764,12 @@ TEST_P(SslSocketTest, UpstreamNotReadySslSocket) { ContextManagerImpl manager(time_system_); ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager, stats_store); + + // Add a secrets ready callback that should not be invoked. + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()).Times(0); + client_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + auto transport_socket = client_ssl_socket_factory.createTransportSocket(nullptr); EXPECT_EQ(EMPTY_STRING, transport_socket->protocol()); EXPECT_EQ(nullptr, transport_socket->ssl()); @@ -4769,6 +4782,183 @@ TEST_P(SslSocketTest, UpstreamNotReadySslSocket) { EXPECT_EQ("TLS error: Secret is not supplied by SDS", transport_socket->failureReason()); } +// Validate that secrets callbacks are invoked when secrets become ready. +TEST_P(SslSocketTest, ClientAddSecretsReadyCallback) { + Stats::TestUtil::TestStore stats_store; + NiceMock local_info; + testing::NiceMock factory_context; + NiceMock init_manager; + NiceMock dispatcher; + EXPECT_CALL(factory_context, localInfo()).WillOnce(ReturnRef(local_info)); + EXPECT_CALL(factory_context, stats()).WillOnce(ReturnRef(stats_store)); + EXPECT_CALL(factory_context, initManager()).WillRepeatedly(ReturnRef(init_manager)); + EXPECT_CALL(factory_context, dispatcher()).WillRepeatedly(ReturnRef(dispatcher)); + + envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context; + auto sds_secret_configs = + tls_context.mutable_common_tls_context()->mutable_tls_certificate_sds_secret_configs()->Add(); + sds_secret_configs->set_name("abc.com"); + sds_secret_configs->mutable_sds_config(); + auto client_cfg = std::make_unique(tls_context, factory_context); + EXPECT_TRUE(client_cfg->tlsCertificates().empty()); + EXPECT_FALSE(client_cfg->isReady()); + + NiceMock context_manager; + ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), context_manager, + stats_store); + + // Add a secrets ready callback. It should not be invoked until onAddOrUpdateSecret() is called. + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()).Times(0); + client_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + + // Call onAddOrUpdateSecret, but return a null ssl_ctx. This should not invoke the callback. + EXPECT_CALL(context_manager, createSslClientContext(_, _, _)).WillOnce(Return(nullptr)); + client_ssl_socket_factory.onAddOrUpdateSecret(); + + EXPECT_CALL(mock_callback_, Call()); + Ssl::ClientContextSharedPtr mock_context = std::make_shared(); + EXPECT_CALL(context_manager, createSslClientContext(_, _, _)).WillOnce(Return(mock_context)); + client_ssl_socket_factory.onAddOrUpdateSecret(); + + // Add another callback, it should be invoked immediately. + MockFunction second_callback_; + EXPECT_CALL(second_callback_, Call()); + client_ssl_socket_factory.addReadyCb(second_callback_.AsStdFunction()); +} + +// Validate that secrets callbacks are invoked when secrets become ready. +TEST_P(SslSocketTest, ServerAddSecretsReadyCallback) { + Stats::TestUtil::TestStore stats_store; + NiceMock local_info; + testing::NiceMock factory_context; + NiceMock init_manager; + NiceMock dispatcher; + EXPECT_CALL(factory_context, localInfo()).WillOnce(ReturnRef(local_info)); + EXPECT_CALL(factory_context, stats()).WillOnce(ReturnRef(stats_store)); + EXPECT_CALL(factory_context, initManager()).WillRepeatedly(ReturnRef(init_manager)); + EXPECT_CALL(factory_context, dispatcher()).WillRepeatedly(ReturnRef(dispatcher)); + + envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; + auto sds_secret_configs = + tls_context.mutable_common_tls_context()->mutable_tls_certificate_sds_secret_configs()->Add(); + sds_secret_configs->set_name("abc.com"); + sds_secret_configs->mutable_sds_config(); + auto server_cfg = std::make_unique(tls_context, factory_context); + EXPECT_TRUE(server_cfg->tlsCertificates().empty()); + EXPECT_FALSE(server_cfg->isReady()); + + NiceMock context_manager; + ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), context_manager, + stats_store, std::vector{}); + + // Add a secrets ready callback. It should not be invoked until onAddOrUpdateSecret() is called. + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()).Times(0); + server_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + + // Call onAddOrUpdateSecret, but return a null ssl_ctx. This should not invoke the callback. + EXPECT_CALL(context_manager, createSslServerContext(_, _, _, _)).WillOnce(Return(nullptr)); + server_ssl_socket_factory.onAddOrUpdateSecret(); + + // Now return a ssl context which should result in the callback being invoked. + EXPECT_CALL(mock_callback_, Call()); + Ssl::ServerContextSharedPtr mock_context = std::make_shared(); + EXPECT_CALL(context_manager, createSslServerContext(_, _, _, _)).WillOnce(Return(mock_context)); + server_ssl_socket_factory.onAddOrUpdateSecret(); + + // Add another callback, it should be invoked immediately. + MockFunction second_callback_; + EXPECT_CALL(second_callback_, Call()); + server_ssl_socket_factory.addReadyCb(second_callback_.AsStdFunction()); +} + +// Tests adding a callback and adding secrets in parallel. This is intended to +// catch a race condition introduced in +// https://github.com/envoyproxy/envoy/pull/13516 where a callback would never +// be called due to a concurrency bug. +TEST_P(SslSocketTest, ServerAddSecretsReadyCallbackParallel) { + Stats::TestUtil::TestStore stats_store; + NiceMock local_info; + testing::NiceMock factory_context; + NiceMock init_manager; + NiceMock dispatcher; + EXPECT_CALL(factory_context, localInfo()).WillOnce(ReturnRef(local_info)); + EXPECT_CALL(factory_context, stats()).WillOnce(ReturnRef(stats_store)); + EXPECT_CALL(factory_context, initManager()).WillRepeatedly(ReturnRef(init_manager)); + EXPECT_CALL(factory_context, dispatcher()).WillRepeatedly(ReturnRef(dispatcher)); + + envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; + auto sds_secret_configs = + tls_context.mutable_common_tls_context()->mutable_tls_certificate_sds_secret_configs()->Add(); + sds_secret_configs->set_name("abc.com"); + sds_secret_configs->mutable_sds_config(); + auto server_cfg = std::make_unique(tls_context, factory_context); + EXPECT_TRUE(server_cfg->tlsCertificates().empty()); + EXPECT_FALSE(server_cfg->isReady()); + + NiceMock context_manager; + ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), context_manager, + stats_store, std::vector{}); + + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()); + Ssl::ServerContextSharedPtr mock_context = std::make_shared(); + EXPECT_CALL(context_manager, createSslServerContext(_, _, _, _)).WillOnce(Return(mock_context)); + + // Add a callback in a thread to potentially tickle the concurrency bug. + std::thread t([&server_ssl_socket_factory, &mock_callback_]() { + server_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + }); + + server_ssl_socket_factory.onAddOrUpdateSecret(); + + t.join(); +} + +// Tests adding a callback and adding secrets in parallel. This is intended to +// catch a race condition introduced in +// https://github.com/envoyproxy/envoy/pull/13516 where a callback would never +// be called due to a concurrency bug. +TEST_P(SslSocketTest, ClientAddSecretsReadyCallbackParallel) { + Stats::TestUtil::TestStore stats_store; + NiceMock local_info; + testing::NiceMock factory_context; + NiceMock init_manager; + NiceMock dispatcher; + EXPECT_CALL(factory_context, localInfo()).WillOnce(ReturnRef(local_info)); + EXPECT_CALL(factory_context, stats()).WillOnce(ReturnRef(stats_store)); + EXPECT_CALL(factory_context, initManager()).WillRepeatedly(ReturnRef(init_manager)); + EXPECT_CALL(factory_context, dispatcher()).WillRepeatedly(ReturnRef(dispatcher)); + + envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context; + auto sds_secret_configs = + tls_context.mutable_common_tls_context()->mutable_tls_certificate_sds_secret_configs()->Add(); + sds_secret_configs->set_name("abc.com"); + sds_secret_configs->mutable_sds_config(); + auto client_cfg = std::make_unique(tls_context, factory_context); + EXPECT_TRUE(client_cfg->tlsCertificates().empty()); + EXPECT_FALSE(client_cfg->isReady()); + + NiceMock context_manager; + ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), context_manager, + stats_store); + + MockFunction mock_callback_; + EXPECT_CALL(mock_callback_, Call()); + Ssl::ClientContextSharedPtr mock_context = std::make_shared(); + EXPECT_CALL(context_manager, createSslClientContext(_, _, _)).WillOnce(Return(mock_context)); + + // Add a callback in a thread to potentially tickle the concurrency bug. + std::thread t([&client_ssl_socket_factory, &mock_callback_]() { + client_ssl_socket_factory.addReadyCb(mock_callback_.AsStdFunction()); + }); + + client_ssl_socket_factory.onAddOrUpdateSecret(); + + t.join(); +} + TEST_P(SslSocketTest, TestTransportSocketCallback) { // Make MockTransportSocketCallbacks. Network::MockIoHandle io_handle; diff --git a/test/mocks/network/transport_socket.h b/test/mocks/network/transport_socket.h index ebb9e26c28904..9fb523322643a 100644 --- a/test/mocks/network/transport_socket.h +++ b/test/mocks/network/transport_socket.h @@ -41,6 +41,7 @@ class MockTransportSocketFactory : public TransportSocketFactory { MOCK_METHOD(bool, supportsAlpn, (), (const)); MOCK_METHOD(TransportSocketPtr, createTransportSocket, (TransportSocketOptionsConstSharedPtr), (const)); + MOCK_METHOD(void, addReadyCb, (std::function)); }; } // namespace Network diff --git a/test/mocks/ssl/mocks.cc b/test/mocks/ssl/mocks.cc index 6e3ea24b32598..24e8a9f263b88 100644 --- a/test/mocks/ssl/mocks.cc +++ b/test/mocks/ssl/mocks.cc @@ -22,6 +22,9 @@ MockClientContextConfig::MockClientContextConfig() { } MockClientContextConfig::~MockClientContextConfig() = default; +MockServerContext::MockServerContext() = default; +MockServerContext::~MockServerContext() = default; + MockServerContextConfig::MockServerContextConfig() = default; MockServerContextConfig::~MockServerContextConfig() = default; diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 43ad275fce277..dbfe59fae27b9 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -107,6 +107,17 @@ class MockClientContextConfig : public ClientContextConfig { std::string test_{}; }; +class MockServerContext : public ServerContext { +public: + MockServerContext(); + ~MockServerContext() override; + + MOCK_METHOD(size_t, daysUntilFirstCertExpires, (), (const)); + MOCK_METHOD(absl::optional, secondsUntilFirstOcspResponseExpires, (), (const)); + MOCK_METHOD(CertificateDetailsPtr, getCaCertInformation, (), (const)); + MOCK_METHOD(std::vector, getCertChainInformation, (), (const)); +}; + class MockServerContextConfig : public ServerContextConfig { public: MockServerContextConfig(); diff --git a/test/mocks/upstream/host.h b/test/mocks/upstream/host.h index cd6cc8e21bc44..07a0db33623b9 100644 --- a/test/mocks/upstream/host.h +++ b/test/mocks/upstream/host.h @@ -198,6 +198,8 @@ class MockHost : public Host { MOCK_METHOD(void, priority, (uint32_t)); MOCK_METHOD(bool, warmed, (), (const)); MOCK_METHOD(MonotonicTime, creationTime, (), (const)); + MOCK_METHOD(void, addHealthCheckingReadyCb, + (std::function, const envoy::config::core::v3::Metadata*), (const)); testing::NiceMock cluster_; Network::TransportSocketFactoryPtr socket_factory_;