diff --git a/include/envoy/ssl/context.h b/include/envoy/ssl/context.h index 5af2fa804fb8a..b3d63bfe45b78 100644 --- a/include/envoy/ssl/context.h +++ b/include/envoy/ssl/context.h @@ -32,12 +32,13 @@ class Context { */ virtual std::string getCertChainInformation() const PURE; }; +typedef std::shared_ptr ContextSharedPtr; class ClientContext : public virtual Context {}; -typedef std::unique_ptr ClientContextPtr; +typedef std::shared_ptr ClientContextSharedPtr; class ServerContext : public virtual Context {}; -typedef std::unique_ptr ServerContextPtr; +typedef std::shared_ptr ServerContextSharedPtr; } // namespace Ssl } // namespace Envoy diff --git a/include/envoy/ssl/context_manager.h b/include/envoy/ssl/context_manager.h index 7489800c99caf..ea63ab9981f05 100644 --- a/include/envoy/ssl/context_manager.h +++ b/include/envoy/ssl/context_manager.h @@ -19,13 +19,13 @@ class ContextManager { /** * Builds a ClientContext from a ClientContextConfig. */ - virtual ClientContextPtr createSslClientContext(Stats::Scope& scope, - const ClientContextConfig& config) PURE; + virtual ClientContextSharedPtr createSslClientContext(Stats::Scope& scope, + const ClientContextConfig& config) PURE; /** * Builds a ServerContext from a ServerContextConfig. */ - virtual ServerContextPtr + virtual ServerContextSharedPtr createSslServerContext(Stats::Scope& scope, const ServerContextConfig& config, const std::vector& server_names) PURE; diff --git a/source/common/ssl/context_impl.cc b/source/common/ssl/context_impl.cc index f380e1de2532f..d2960feebab8a 100644 --- a/source/common/ssl/context_impl.cc +++ b/source/common/ssl/context_impl.cc @@ -29,10 +29,8 @@ int ContextImpl::sslContextIndex() { }()); } -ContextImpl::ContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, - const ContextConfig& config) - : parent_(parent), ctx_(SSL_CTX_new(TLS_method())), scope_(scope), - stats_(generateStats(scope)) { +ContextImpl::ContextImpl(Stats::Scope& scope, const ContextConfig& config) + : ctx_(SSL_CTX_new(TLS_method())), scope_(scope), stats_(generateStats(scope)) { RELEASE_ASSERT(ctx_, ""); int rc = SSL_CTX_set_ex_data(ctx_.get(), sslContextIndex(), this); @@ -461,9 +459,8 @@ std::string ContextImpl::getCertChainInformation() const { getDaysUntilExpiration(cert_chain_.get())); } -ClientContextImpl::ClientContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, - const ClientContextConfig& config) - : ContextImpl(parent, scope, config), server_name_indication_(config.serverNameIndication()), +ClientContextImpl::ClientContextImpl(Stats::Scope& scope, const ClientContextConfig& config) + : ContextImpl(scope, config), server_name_indication_(config.serverNameIndication()), allow_renegotiation_(config.allowRenegotiation()) { if (!parsed_alpn_protocols_.empty()) { int rc = SSL_CTX_set_alpn_protos(ctx_.get(), &parsed_alpn_protocols_[0], @@ -487,11 +484,10 @@ bssl::UniquePtr ClientContextImpl::newSsl() const { return ssl_con; } -ServerContextImpl::ServerContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, - const ServerContextConfig& config, +ServerContextImpl::ServerContextImpl(Stats::Scope& scope, const ServerContextConfig& config, const std::vector& server_names, Runtime::Loader& runtime) - : ContextImpl(parent, scope, config), runtime_(runtime), + : ContextImpl(scope, config), runtime_(runtime), session_ticket_keys_(config.sessionTicketKeys()) { if (config.certChain().empty()) { throw EnvoyException("Server TlsCertificates must have a certificate specified"); diff --git a/source/common/ssl/context_impl.h b/source/common/ssl/context_impl.h index 9e28cdf18832e..698769d5b9201 100644 --- a/source/common/ssl/context_impl.h +++ b/source/common/ssl/context_impl.h @@ -75,8 +75,7 @@ class ContextImpl : public virtual Context { std::string getCertChainInformation() const override; protected: - ContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, const ContextConfig& config); - ~ContextImpl() { parent_.releaseContext(this); } + ContextImpl(Stats::Scope& scope, const ContextConfig& config); /** * The global SSL-library index used for storing a pointer to the context @@ -123,7 +122,6 @@ class ContextImpl : public virtual Context { std::string getCaFileName() const { return ca_file_path_; }; std::string getCertChainFileName() const { return cert_chain_file_path_; }; - ContextManagerImpl& parent_; bssl::UniquePtr ctx_; bool verify_trusted_ca_{false}; std::vector verify_subject_alt_name_list_; @@ -138,10 +136,11 @@ class ContextImpl : public virtual Context { std::string cert_chain_file_path_; }; +typedef std::shared_ptr ContextImplSharedPtr; + class ClientContextImpl : public ContextImpl, public ClientContext { public: - ClientContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, - const ClientContextConfig& config); + ClientContextImpl(Stats::Scope& scope, const ClientContextConfig& config); bssl::UniquePtr newSsl() const override; @@ -152,9 +151,8 @@ class ClientContextImpl : public ContextImpl, public ClientContext { class ServerContextImpl : public ContextImpl, public ServerContext { public: - ServerContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, - const ServerContextConfig& config, const std::vector& server_names, - Runtime::Loader& runtime); + ServerContextImpl(Stats::Scope& scope, const ServerContextConfig& config, + const std::vector& server_names, Runtime::Loader& runtime); private: int alpnSelectCallback(const unsigned char** out, unsigned char* outlen, const unsigned char* in, diff --git a/source/common/ssl/context_manager_impl.cc b/source/common/ssl/context_manager_impl.cc index 1c54c2f018656..b82fd96151af6 100644 --- a/source/common/ssl/context_manager_impl.cc +++ b/source/common/ssl/context_manager_impl.cc @@ -9,47 +9,50 @@ namespace Envoy { namespace Ssl { -ContextManagerImpl::~ContextManagerImpl() { ASSERT(contexts_.empty()); } - -void ContextManagerImpl::releaseContext(Context* context) { - std::unique_lock lock(contexts_lock_); +ContextManagerImpl::~ContextManagerImpl() { + removeEmptyContexts(); + ASSERT(contexts_.empty()); +} - // context may not be found, in the case that a subclass of Context throws - // in it's constructor. In that case the context did not get added, but - // the destructor of Context will run and call releaseContext(). - contexts_.remove(context); +void ContextManagerImpl::removeEmptyContexts() { + contexts_.remove_if([](const std::weak_ptr& n) { return n.expired(); }); } -ClientContextPtr ContextManagerImpl::createSslClientContext(Stats::Scope& scope, - const ClientContextConfig& config) { - ClientContextPtr context(new ClientContextImpl(*this, scope, config)); - std::unique_lock lock(contexts_lock_); - contexts_.emplace_back(context.get()); +ClientContextSharedPtr +ContextManagerImpl::createSslClientContext(Stats::Scope& scope, const ClientContextConfig& config) { + ClientContextSharedPtr context = std::make_shared(scope, config); + removeEmptyContexts(); + contexts_.emplace_back(context); return context; } -ServerContextPtr +ServerContextSharedPtr ContextManagerImpl::createSslServerContext(Stats::Scope& scope, const ServerContextConfig& config, const std::vector& server_names) { - ServerContextPtr context(new ServerContextImpl(*this, scope, config, server_names, runtime_)); - std::unique_lock lock(contexts_lock_); - contexts_.emplace_back(context.get()); + ServerContextSharedPtr context = + std::make_shared(scope, config, server_names, runtime_); + removeEmptyContexts(); + contexts_.emplace_back(context); return context; } size_t ContextManagerImpl::daysUntilFirstCertExpires() const { - std::shared_lock lock(contexts_lock_); size_t ret = std::numeric_limits::max(); - for (Context* context : contexts_) { - ret = std::min(context->daysUntilFirstCertExpires(), ret); + for (const auto& ctx_weak_ptr : contexts_) { + ContextSharedPtr context = ctx_weak_ptr.lock(); + if (context) { + ret = std::min(context->daysUntilFirstCertExpires(), ret); + } } return ret; } void ContextManagerImpl::iterateContexts(std::function callback) { - std::shared_lock lock(contexts_lock_); - for (Context* context : contexts_) { - callback(*context); + for (const auto& ctx_weak_ptr : contexts_) { + ContextSharedPtr context = ctx_weak_ptr.lock(); + if (context) { + callback(*context); + } } } diff --git a/source/common/ssl/context_manager_impl.h b/source/common/ssl/context_manager_impl.h index bd31db4f22008..d330e0dabed4f 100644 --- a/source/common/ssl/context_manager_impl.h +++ b/source/common/ssl/context_manager_impl.h @@ -22,26 +22,19 @@ class ContextManagerImpl final : public ContextManager { ContextManagerImpl(Runtime::Loader& runtime) : runtime_(runtime) {} ~ContextManagerImpl(); - /** - * Allocated contexts are owned by the caller. However, we need to be able to iterate them for - * admin purposes. When a caller frees a context it will tell us to release it also from the list - * of contexts. - */ - void releaseContext(Context* context); - // Ssl::ContextManager - Ssl::ClientContextPtr createSslClientContext(Stats::Scope& scope, - const ClientContextConfig& config) override; - Ssl::ServerContextPtr + Ssl::ClientContextSharedPtr createSslClientContext(Stats::Scope& scope, + const ClientContextConfig& config) override; + Ssl::ServerContextSharedPtr createSslServerContext(Stats::Scope& scope, const ServerContextConfig& config, const std::vector& server_names) override; size_t daysUntilFirstCertExpires() const override; void iterateContexts(std::function callback) override; private: + void removeEmptyContexts(); Runtime::Loader& runtime_; - std::list contexts_; - mutable std::shared_timed_mutex contexts_lock_; + std::list> contexts_; }; } // namespace Ssl diff --git a/source/common/ssl/ssl_socket.cc b/source/common/ssl/ssl_socket.cc index c6b31e6293a2c..ec6d9967c6aa8 100644 --- a/source/common/ssl/ssl_socket.cc +++ b/source/common/ssl/ssl_socket.cc @@ -15,8 +15,8 @@ using Envoy::Network::PostIoAction; namespace Envoy { namespace Ssl { -SslSocket::SslSocket(Context& ctx, InitialState state) - : ctx_(dynamic_cast(ctx)), ssl_(ctx_.newSsl()) { +SslSocket::SslSocket(ContextSharedPtr ctx, InitialState state) + : ctx_(std::dynamic_pointer_cast(ctx)), ssl_(ctx_->newSsl()) { if (state == InitialState::Client) { SSL_set_connect_state(ssl_.get()); } else { @@ -99,7 +99,7 @@ PostIoAction SslSocket::doHandshake() { if (rc == 1) { ENVOY_CONN_LOG(debug, "handshake complete", callbacks_->connection()); handshake_complete_ = true; - ctx_.logHandshake(ssl_.get()); + ctx_->logHandshake(ssl_.get()); callbacks_->raiseEvent(Network::ConnectionEvent::Connected); // It's possible that we closed during the handshake callback. @@ -126,7 +126,7 @@ void SslSocket::drainErrorQueue() { while (uint64_t err = ERR_get_error()) { if (ERR_GET_LIB(err) == ERR_LIB_SSL) { if (ERR_GET_REASON(err) == SSL_R_PEER_DID_NOT_RETURN_A_CERTIFICATE) { - ctx_.stats().fail_verify_no_cert_.inc(); + ctx_->stats().fail_verify_no_cert_.inc(); saw_counted_error = true; } else if (ERR_GET_REASON(err) == SSL_R_CERTIFICATE_VERIFY_FAILED) { saw_counted_error = true; @@ -139,7 +139,7 @@ void SslSocket::drainErrorQueue() { ERR_reason_error_string(err)); } if (saw_error && !saw_counted_error) { - ctx_.stats().connection_error_.inc(); + ctx_->stats().connection_error_.inc(); } } @@ -388,7 +388,7 @@ ClientSslSocketFactory::ClientSslSocketFactory(const ClientContextConfig& config : ssl_ctx_(manager.createSslClientContext(stats_scope, config)) {} Network::TransportSocketPtr ClientSslSocketFactory::createTransportSocket() const { - return std::make_unique(*ssl_ctx_, Ssl::InitialState::Client); + return std::make_unique(ssl_ctx_, Ssl::InitialState::Client); } bool ClientSslSocketFactory::implementsSecureTransport() const { return true; } @@ -400,7 +400,7 @@ ServerSslSocketFactory::ServerSslSocketFactory(const ServerContextConfig& config : ssl_ctx_(manager.createSslServerContext(stats_scope, config, server_names)) {} Network::TransportSocketPtr ServerSslSocketFactory::createTransportSocket() const { - return std::make_unique(*ssl_ctx_, Ssl::InitialState::Server); + return std::make_unique(ssl_ctx_, Ssl::InitialState::Server); } bool ServerSslSocketFactory::implementsSecureTransport() const { return true; } diff --git a/source/common/ssl/ssl_socket.h b/source/common/ssl/ssl_socket.h index be4fa9aeb4992..68fec106eb916 100644 --- a/source/common/ssl/ssl_socket.h +++ b/source/common/ssl/ssl_socket.h @@ -20,7 +20,7 @@ class SslSocket : public Network::TransportSocket, public Connection, protected Logger::Loggable { public: - SslSocket(Context& ctx, InitialState state); + SslSocket(ContextSharedPtr ctx, InitialState state); // Ssl::Connection bool peerCertificatePresented() const override; @@ -58,7 +58,7 @@ class SslSocket : public Network::TransportSocket, std::vector getDnsSansFromCertificate(X509* cert); Network::TransportSocketCallbacks* callbacks_{}; - ContextImpl& ctx_; + ContextImplSharedPtr ctx_; bssl::UniquePtr ssl_; bool handshake_complete_{}; bool shutdown_sent_{}; @@ -71,22 +71,24 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory { public: ClientSslSocketFactory(const ClientContextConfig& config, Ssl::ContextManager& manager, Stats::Scope& stats_scope); + Network::TransportSocketPtr createTransportSocket() const override; bool implementsSecureTransport() const override; private: - const ClientContextPtr ssl_ctx_; + ClientContextSharedPtr ssl_ctx_; }; class ServerSslSocketFactory : public Network::TransportSocketFactory { public: ServerSslSocketFactory(const ServerContextConfig& config, Ssl::ContextManager& manager, Stats::Scope& stats_scope, const std::vector& server_names); + Network::TransportSocketPtr createTransportSocket() const override; bool implementsSecureTransport() const override; private: - const ServerContextPtr ssl_ctx_; + ServerContextSharedPtr ssl_ctx_; }; } // namespace Ssl diff --git a/test/common/grpc/grpc_client_integration_test_harness.h b/test/common/grpc/grpc_client_integration_test_harness.h index f6c9c3fa152c6..0ce4ae7e44930 100644 --- a/test/common/grpc/grpc_client_integration_test_harness.h +++ b/test/common/grpc/grpc_client_integration_test_harness.h @@ -399,6 +399,7 @@ class GrpcClientIntegrationTest : public GrpcClientIntegrationParamTest { Upstream::MockThreadLocalCluster thread_local_cluster_; NiceMock local_info_; Runtime::MockLoader runtime_; + Ssl::ContextManagerImpl context_manager_{runtime_}; NiceMock random_; Http::AsyncClientPtr http_async_client_; Http::ConnectionPool::InstancePtr http_conn_pool_; @@ -421,6 +422,7 @@ class GrpcSslClientIntegrationTest : public GrpcClientIntegrationTest { // doesn't like dangling contexts at destruction. GrpcClientIntegrationTest::TearDown(); fake_upstream_.reset(); + async_client_transport_socket_.reset(); client_connection_.reset(); mock_cluster_info_->transport_socket_factory_.reset(); } @@ -483,7 +485,6 @@ class GrpcSslClientIntegrationTest : public GrpcClientIntegrationTest { bool use_client_cert_{}; Secret::MockSecretManager secret_manager_; - Ssl::ContextManagerImpl context_manager_{runtime_}; }; } // namespace diff --git a/test/common/ssl/context_impl_test.cc b/test/common/ssl/context_impl_test.cc index de82f328a7570..8efaad4f499fa 100644 --- a/test/common/ssl/context_impl_test.cc +++ b/test/common/ssl/context_impl_test.cc @@ -99,7 +99,7 @@ TEST_F(SslContextImplTest, TestExpiringCert) { Runtime::MockLoader runtime; ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - ClientContextPtr context(manager.createSslClientContext(store, cfg)); + ClientContextSharedPtr context(manager.createSslClientContext(store, cfg)); // This is a total hack, but right now we generate the cert and it expires in 15 days only in the // first second that it's valid. This can become invalid and then cause slower tests to fail. @@ -122,7 +122,7 @@ TEST_F(SslContextImplTest, TestExpiredCert) { Runtime::MockLoader runtime; ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - ClientContextPtr context(manager.createSslClientContext(store, cfg)); + ClientContextSharedPtr context(manager.createSslClientContext(store, cfg)); EXPECT_EQ(0U, context->daysUntilFirstCertExpires()); } @@ -141,7 +141,7 @@ TEST_F(SslContextImplTest, TestGetCertInformation) { ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - ClientContextPtr context(manager.createSslClientContext(store, cfg)); + ClientContextSharedPtr context(manager.createSslClientContext(store, cfg)); // This is similar to the hack above, but right now we generate the ca_cert and it expires in 15 // days only in the first second that it's valid. We will partially match for up until Days until // Expiration: 1. @@ -166,7 +166,7 @@ TEST_F(SslContextImplTest, TestNoCert) { Runtime::MockLoader runtime; ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - ClientContextPtr context(manager.createSslClientContext(store, cfg)); + ClientContextSharedPtr context(manager.createSslClientContext(store, cfg)); EXPECT_EQ("", context->getCaCertInformation()); EXPECT_EQ("", context->getCertChainInformation()); } @@ -178,7 +178,7 @@ class SslServerContextImplTicketTest : public SslContextImplTest { Secret::MockSecretManager secret_manager; ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - ServerContextPtr server_ctx( + ServerContextSharedPtr server_ctx( manager.createSslServerContext(store, cfg, std::vector{})); } @@ -500,7 +500,7 @@ TEST(ServerContextImplTest, TlsCertificateNonEmpty) { Runtime::MockLoader runtime; ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - EXPECT_THROW_WITH_MESSAGE(ServerContextPtr server_ctx(manager.createSslServerContext( + EXPECT_THROW_WITH_MESSAGE(ServerContextSharedPtr server_ctx(manager.createSslServerContext( store, client_context_config, std::vector{})), EnvoyException, "Server TlsCertificates must have a certificate specified"); diff --git a/test/common/ssl/ssl_socket_test.cc b/test/common/ssl/ssl_socket_test.cc index 0c24760e04eee..35925d673ef75 100644 --- a/test/common/ssl/ssl_socket_test.cc +++ b/test/common/ssl/ssl_socket_test.cc @@ -170,7 +170,7 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, ClientContextConfigImpl client_ctx_config(client_ctx_proto, secret_manager); ClientSslSocketFactory client_ssl_socket_factory(client_ctx_config, manager, stats_store); - ClientContextPtr client_ctx(manager.createSslClientContext(stats_store, client_ctx_config)); + ClientContextSharedPtr client_ctx(manager.createSslClientContext(stats_store, client_ctx_config)); Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), client_ssl_socket_factory.createTransportSocket(), nullptr); @@ -2754,7 +2754,7 @@ class SslReadBufferLimitTest : public SslCertsTest, Network::ListenerPtr listener_; Json::ObjectSharedPtr client_ctx_loader_; std::unique_ptr client_ctx_config_; - ClientContextPtr client_ctx_; + ClientContextSharedPtr client_ctx_; Network::TransportSocketFactoryPtr client_ssl_socket_factory_; Network::ClientConnectionPtr client_connection_; Network::ConnectionPtr server_connection_; diff --git a/test/integration/ssl_integration_test.cc b/test/integration/ssl_integration_test.cc index 2cbd941cfa9a4..3ed603e9b37c6 100644 --- a/test/integration/ssl_integration_test.cc +++ b/test/integration/ssl_integration_test.cc @@ -50,6 +50,8 @@ void SslIntegrationTest::TearDown() { client_ssl_ctx_alpn_.reset(); client_ssl_ctx_san_.reset(); client_ssl_ctx_alpn_san_.reset(); + HttpIntegrationTest::cleanupUpstreamAndDownstream(); + codec_client_.reset(); context_manager_.reset(); runtime_.reset(); } diff --git a/test/integration/xfcc_integration_test.cc b/test/integration/xfcc_integration_test.cc index f0d4f70eda79d..fa2116055374b 100644 --- a/test/integration/xfcc_integration_test.cc +++ b/test/integration/xfcc_integration_test.cc @@ -31,6 +31,8 @@ void XfccIntegrationTest::TearDown() { client_tls_ssl_ctx_.reset(); fake_upstream_connection_.reset(); fake_upstreams_.clear(); + HttpIntegrationTest::cleanupUpstreamAndDownstream(); + codec_client_.reset(); context_manager_.reset(); runtime_.reset(); } diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index d93daf7005b08..033e58afd9d5f 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -21,21 +21,11 @@ class MockContextManager : public ContextManager { MockContextManager(); ~MockContextManager(); - ClientContextPtr createSslClientContext(Stats::Scope& scope, - const ClientContextConfig& config) override { - return ClientContextPtr{createSslClientContext_(scope, config)}; - } - - ServerContextPtr createSslServerContext(Stats::Scope& scope, const ServerContextConfig& config, - const std::vector& server_names) override { - return ServerContextPtr{createSslServerContext_(scope, config, server_names)}; - } - - MOCK_METHOD2(createSslClientContext_, - ClientContext*(Stats::Scope& scope, const ClientContextConfig& config)); - MOCK_METHOD3(createSslServerContext_, - ServerContext*(Stats::Scope& stats, const ServerContextConfig& config, - const std::vector& server_names)); + MOCK_METHOD2(createSslClientContext, + ClientContextSharedPtr(Stats::Scope& scope, const ClientContextConfig& config)); + MOCK_METHOD3(createSslServerContext, + ServerContextSharedPtr(Stats::Scope& stats, const ServerContextConfig& config, + const std::vector& server_names)); MOCK_CONST_METHOD0(daysUntilFirstCertExpires, size_t()); MOCK_METHOD1(iterateContexts, void(std::function callback)); };