diff --git a/include/envoy/event/dispatcher.h b/include/envoy/event/dispatcher.h index b912307e6c232..4b4671f4528e0 100644 --- a/include/envoy/event/dispatcher.h +++ b/include/envoy/event/dispatcher.h @@ -42,16 +42,18 @@ class Dispatcher { * Create a server connection. * @param socket supplies an open file descriptor and connection metadata to use for the * connection. Takes ownership of the socket. - * @param ssl_ctx supplies the SSL context to use, if not nullptr. + * @param transport_socket supplies a transport socket to be used by the connection. * @return Network::ConnectionPtr a server connection that is owned by the caller. */ - virtual Network::ConnectionPtr createServerConnection(Network::ConnectionSocketPtr&& socket, - Ssl::Context* ssl_ctx) PURE; + virtual Network::ConnectionPtr + createServerConnection(Network::ConnectionSocketPtr&& socket, + Network::TransportSocketPtr&& transport_socket) PURE; /** * Create a client connection. * @param address supplies the address to connect to. * @param source_address supplies an address to bind to or nullptr if no bind is necessary. + * @param transport_socket supplies a transport socket to be used by the connection. * @param options the socket options to be set on the underlying socket before anything is sent * on the socket. * @return Network::ClientConnectionPtr a client connection that is owned by the caller. diff --git a/include/envoy/network/listener.h b/include/envoy/network/listener.h index 16e1327840e09..7a94528c4b3f8 100644 --- a/include/envoy/network/listener.h +++ b/include/envoy/network/listener.h @@ -7,6 +7,7 @@ #include "envoy/common/exception.h" #include "envoy/network/connection.h" #include "envoy/network/listen_socket.h" +#include "envoy/network/transport_socket.h" #include "envoy/ssl/context.h" namespace Envoy { @@ -32,9 +33,9 @@ class ListenerConfig { virtual ListenSocket& socket() PURE; /** - * @return Ssl::ServerContext* the default SSL context. + * @return TransportSocketFactory& the transport socket factory. */ - virtual Ssl::ServerContext* defaultSslContext() PURE; + virtual TransportSocketFactory& transportSocketFactory() PURE; /** * @return bool specifies whether the listener should actually listen on the port. diff --git a/include/envoy/server/transport_socket_config.h b/include/envoy/server/transport_socket_config.h index f0ce950ada738..fa6b671c7ff73 100644 --- a/include/envoy/server/transport_socket_config.h +++ b/include/envoy/server/transport_socket_config.h @@ -33,11 +33,31 @@ class TransportSocketConfigFactory { public: virtual ~TransportSocketConfigFactory() {} + /** + * @return ProtobufTypes::MessagePtr create empty config proto message. The transport socket + * config, which arrives in an opaque google.protobuf.Struct message, will be converted + * to JSON and then parsed into this empty proto. + */ + virtual ProtobufTypes::MessagePtr createEmptyConfigProto() PURE; + + /** + * @return std::string the identifying name for a particular TransportSocketFactoryPtr + * implementation produced by the factory. + */ + virtual std::string name() const PURE; +}; + +/** + * Implemented by each transport socket used for upstream connections. Registered via class + * RegisterFactory. + */ +class UpstreamTransportSocketConfigFactory : public virtual TransportSocketConfigFactory { +public: /** * Create a particular transport socket factory implementation. * @param config const Protobuf::Message& supplies the config message for the transport socket * implementation. - * @param context TransportSocketFactoryContext& supplies the transport socket's context. + * @param context TransportSocketFactoryContext& supplies the transport socket's context. * @return Network::TransportSocketFactoryPtr the transport socket factory instance. The returned * TransportSocketFactoryPtr should not be nullptr. * @@ -47,23 +67,40 @@ class TransportSocketConfigFactory { virtual Network::TransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& config, TransportSocketFactoryContext& context) PURE; +}; +/** + * Implemented by each transport socket used for downstream connections. Registered via class + * RegisterFactory. + */ +class DownstreamTransportSocketConfigFactory : public virtual TransportSocketConfigFactory { +public: /** - * @return ProtobufTypes::MessagePtr create empty config proto message. The transport socket - * config, which arrives in an opaque google.protobuf.Struct message, will be converted - * to JSON and then parsed into this empty proto. - */ - virtual ProtobufTypes::MessagePtr createEmptyConfigProto() PURE; - - /** - * @return std::string the identifying name for a particular TransportSocketFactoryPtr - * implementation produced by the factory. + * Create a particular downstream transport socket factory implementation. + * TODO(lizan): Revisit the parameters for SNI below when TLS sniffing and filter chain match are + * implemented. + * @param listener_name const std::string& the name of the listener. + * @param server_names const std::vector& the names of the server. This parameter is + * currently used by SNI implementation to know the expected server names. + * @param skip_ssl_context_update bool indicates whether the ssl context update should be skipped. + * This parameter is currently used by SNI implementation to know whether it should perform + * certificate selection. + * @param config const Protobuf::Message& supplies the config message for the transport socket + * implementation. + * @param context TransportSocketFactoryContext& supplies the transport socket's context. + * @return Network::TransportSocketFactoryPtr the transport socket factory instance. The returned + * TransportSocketFactoryPtr should not be nullptr. + * + * @throw EnvoyException if the implementation is unable to produce a factory with the provided + * parameters. */ - virtual std::string name() const PURE; + virtual Network::TransportSocketFactoryPtr + createTransportSocketFactory(const std::string& listener_name, + const std::vector& server_names, + bool skip_ssl_context_update, const Protobuf::Message& config, + TransportSocketFactoryContext& context) PURE; }; -class UpstreamTransportSocketConfigFactory : public virtual TransportSocketConfigFactory {}; - } // namespace Configuration } // namespace Server } // namespace Envoy diff --git a/source/common/event/dispatcher_impl.cc b/source/common/event/dispatcher_impl.cc index 44ae6b0fb76cc..bef41ccdc8424 100644 --- a/source/common/event/dispatcher_impl.cc +++ b/source/common/event/dispatcher_impl.cc @@ -73,13 +73,12 @@ void DispatcherImpl::clearDeferredDeleteList() { deferred_deleting_ = false; } -Network::ConnectionPtr DispatcherImpl::createServerConnection(Network::ConnectionSocketPtr&& socket, - Ssl::Context* ssl_ctx) { +Network::ConnectionPtr +DispatcherImpl::createServerConnection(Network::ConnectionSocketPtr&& socket, + Network::TransportSocketPtr&& transport_socket) { ASSERT(isThreadSafe()); - return Network::ConnectionPtr{ssl_ctx - ? new Ssl::ConnectionImpl(*this, std::move(socket), true, - *ssl_ctx, Ssl::InitialState::Server) - : new Network::ConnectionImpl(*this, std::move(socket), true)}; + return std::make_unique(*this, std::move(socket), + std::move(transport_socket), true); } Network::ClientConnectionPtr diff --git a/source/common/event/dispatcher_impl.h b/source/common/event/dispatcher_impl.h index 5b920041a31ea..500f339c68ecd 100644 --- a/source/common/event/dispatcher_impl.h +++ b/source/common/event/dispatcher_impl.h @@ -33,8 +33,9 @@ class DispatcherImpl : Logger::Loggable, public Dispatcher { // Event::Dispatcher void clearDeferredDeleteList() override; - Network::ConnectionPtr createServerConnection(Network::ConnectionSocketPtr&& socket, - Ssl::Context* ssl_ctx) override; + Network::ConnectionPtr + createServerConnection(Network::ConnectionSocketPtr&& socket, + Network::TransportSocketPtr&& transport_socket) override; Network::ClientConnectionPtr createClientConnection(Network::Address::InstanceConstSharedPtr address, Network::Address::InstanceConstSharedPtr source_address, diff --git a/source/common/ssl/ssl_socket.cc b/source/common/ssl/ssl_socket.cc index 9fd5e0cae115e..3d7f11ec20813 100644 --- a/source/common/ssl/ssl_socket.cc +++ b/source/common/ssl/ssl_socket.cc @@ -350,5 +350,20 @@ Network::TransportSocketPtr ClientSslSocketFactory::createTransportSocket() cons bool ClientSslSocketFactory::implementsSecureTransport() const { return true; } +ServerSslSocketFactory::ServerSslSocketFactory(const ServerContextConfig& config, + const std::string& listener_name, + const std::vector& server_names, + bool skip_context_update, + Ssl::ContextManager& manager, + Stats::Scope& stats_scope) + : ssl_ctx_(manager.createSslServerContext(listener_name, server_names, stats_scope, config, + skip_context_update)) {} + +Network::TransportSocketPtr ServerSslSocketFactory::createTransportSocket() const { + return std::make_unique(*ssl_ctx_, Ssl::InitialState::Server); +} + +bool ServerSslSocketFactory::implementsSecureTransport() const { return true; } + } // namespace Ssl } // namespace Envoy diff --git a/source/common/ssl/ssl_socket.h b/source/common/ssl/ssl_socket.h index 6437fa89c54a8..48ce68a4c63f9 100644 --- a/source/common/ssl/ssl_socket.h +++ b/source/common/ssl/ssl_socket.h @@ -63,7 +63,19 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory { bool implementsSecureTransport() const override; private: - ClientContextPtr ssl_ctx_; + const ClientContextPtr ssl_ctx_; +}; + +class ServerSslSocketFactory : public Network::TransportSocketFactory { +public: + ServerSslSocketFactory(const ServerContextConfig& config, const std::string& listener_name, + const std::vector& server_names, bool skip_context_update, + Ssl::ContextManager& manager, Stats::Scope& stats_scope); + Network::TransportSocketPtr createTransportSocket() const override; + bool implementsSecureTransport() const override; + +private: + const ServerContextPtr ssl_ctx_; }; } // namespace Ssl diff --git a/source/common/upstream/upstream_impl.cc b/source/common/upstream/upstream_impl.cc index 0a2c325aeb70b..4bb16fa39bfcf 100644 --- a/source/common/upstream/upstream_impl.cc +++ b/source/common/upstream/upstream_impl.cc @@ -118,6 +118,9 @@ ClusterInfoImpl::ClusterInfoImpl(const envoy::api::v2::Cluster& config, lb_subset_(LoadBalancerSubsetInfoImpl(config.lb_subset_config())), metadata_(config.metadata()) { + // If the cluster doesn't have transport socke configured, override with default transport + // socket implementation based on tls_context. We copy by value first then override if + // neccessary. auto transport_socket = config.transport_socket(); if (!config.has_transport_socket()) { if (config.has_tls_context()) { diff --git a/source/server/BUILD b/source/server/BUILD index ee6885628c3ef..c83baeea051b5 100644 --- a/source/server/BUILD +++ b/source/server/BUILD @@ -199,6 +199,7 @@ envoy_cc_library( ":init_manager_lib", "//include/envoy/server:filter_config_interface", "//include/envoy/server:listener_manager_interface", + "//include/envoy/server:transport_socket_config_interface", "//include/envoy/server:worker_interface", "//source/common/config:utility_lib", "//source/common/config:well_known_names", diff --git a/source/server/config/network/raw_buffer_socket.cc b/source/server/config/network/raw_buffer_socket.cc index 6620911e7fdf9..77009a9c15e86 100644 --- a/source/server/config/network/raw_buffer_socket.cc +++ b/source/server/config/network/raw_buffer_socket.cc @@ -9,8 +9,14 @@ namespace Server { namespace Configuration { Network::TransportSocketFactoryPtr -RawBufferSocketFactory::createTransportSocketFactory(const Protobuf::Message&, - TransportSocketFactoryContext&) { +UpstreamRawBufferSocketFactory::createTransportSocketFactory(const Protobuf::Message&, + TransportSocketFactoryContext&) { + return std::make_unique(); +} + +Network::TransportSocketFactoryPtr DownstreamRawBufferSocketFactory::createTransportSocketFactory( + const std::string&, const std::vector&, bool, const Protobuf::Message&, + TransportSocketFactoryContext&) { return std::make_unique(); } @@ -22,6 +28,10 @@ static Registry::RegisterFactory upstream_registered_; +static Registry::RegisterFactory + downstream_registered_; + } // namespace Configuration } // namespace Server } // namespace Envoy diff --git a/source/server/config/network/raw_buffer_socket.h b/source/server/config/network/raw_buffer_socket.h index 5edcd8410e1e2..63457bccfe373 100644 --- a/source/server/config/network/raw_buffer_socket.h +++ b/source/server/config/network/raw_buffer_socket.h @@ -16,14 +16,26 @@ class RawBufferSocketFactory : public virtual TransportSocketConfigFactory { public: virtual ~RawBufferSocketFactory() {} std::string name() const override { return Config::TransportSocketNames::get().RAW_BUFFER; } + ProtobufTypes::MessagePtr createEmptyConfigProto() override; +}; + +class UpstreamRawBufferSocketFactory : public UpstreamTransportSocketConfigFactory, + public RawBufferSocketFactory { +public: Network::TransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& config, TransportSocketFactoryContext& context) override; - ProtobufTypes::MessagePtr createEmptyConfigProto() override; }; -class UpstreamRawBufferSocketFactory : public UpstreamTransportSocketConfigFactory, - public RawBufferSocketFactory {}; +class DownstreamRawBufferSocketFactory : public DownstreamTransportSocketConfigFactory, + public RawBufferSocketFactory { +public: + Network::TransportSocketFactoryPtr + createTransportSocketFactory(const std::string& listener_name, + const std::vector& server_names, + bool skip_context_update, const Protobuf::Message& config, + TransportSocketFactoryContext& context) override; +}; } // namespace Configuration } // namespace Server diff --git a/source/server/config/network/ssl_socket.cc b/source/server/config/network/ssl_socket.cc index e4916d7392d82..22ff8610fab65 100644 --- a/source/server/config/network/ssl_socket.cc +++ b/source/server/config/network/ssl_socket.cc @@ -30,6 +30,25 @@ ProtobufTypes::MessagePtr UpstreamSslSocketFactory::createEmptyConfigProto() { static Registry::RegisterFactory upstream_registered_; +Network::TransportSocketFactoryPtr DownstreamSslSocketFactory::createTransportSocketFactory( + const std::string& listener_name, const std::vector& server_names, + bool skip_context_update, const Protobuf::Message& message, + TransportSocketFactoryContext& context) { + return std::make_unique( + Ssl::ServerContextConfigImpl( + MessageUtil::downcastAndValidate( + message)), + listener_name, server_names, skip_context_update, context.sslContextManager(), + context.statsScope()); +} + +ProtobufTypes::MessagePtr DownstreamSslSocketFactory::createEmptyConfigProto() { + return std::make_unique(); +} + +static Registry::RegisterFactory + downstream_registered_; + } // namespace Configuration } // namespace Server } // namespace Envoy diff --git a/source/server/config/network/ssl_socket.h b/source/server/config/network/ssl_socket.h index 083d0ee7777c1..983c3a1bd8dd1 100644 --- a/source/server/config/network/ssl_socket.h +++ b/source/server/config/network/ssl_socket.h @@ -27,6 +27,17 @@ class UpstreamSslSocketFactory : public UpstreamTransportSocketConfigFactory, ProtobufTypes::MessagePtr createEmptyConfigProto() override; }; +class DownstreamSslSocketFactory : public DownstreamTransportSocketConfigFactory, + public SslSocketConfigFactory { +public: + Network::TransportSocketFactoryPtr + createTransportSocketFactory(const std::string& listener_name, + const std::vector& server_names, + bool skip_context_update, const Protobuf::Message& config, + TransportSocketFactoryContext& context) override; + ProtobufTypes::MessagePtr createEmptyConfigProto() override; +}; + } // namespace Configuration } // namespace Server } // namespace Envoy diff --git a/source/server/connection_handler_impl.cc b/source/server/connection_handler_impl.cc index 444bf40fae606..89d807e966418 100644 --- a/source/server/connection_handler_impl.cc +++ b/source/server/connection_handler_impl.cc @@ -184,8 +184,8 @@ void ConnectionHandlerImpl::ActiveListener::onAccept( } void ConnectionHandlerImpl::ActiveListener::newConnection(Network::ConnectionSocketPtr&& socket) { - Network::ConnectionPtr new_connection = - parent_.dispatcher_.createServerConnection(std::move(socket), config_.defaultSslContext()); + Network::ConnectionPtr new_connection = parent_.dispatcher_.createServerConnection( + std::move(socket), config_.transportSocketFactory().createTransportSocket()); new_connection->setBufferLimits(config_.perConnectionBufferLimitBytes()); onNewConnection(std::move(new_connection)); } diff --git a/source/server/http/BUILD b/source/server/http/BUILD index 88988cc11c348..23514bd8a6808 100644 --- a/source/server/http/BUILD +++ b/source/server/http/BUILD @@ -44,6 +44,7 @@ envoy_cc_library( "//source/common/http:utility_lib", "//source/common/http/http1:codec_lib", "//source/common/network:listen_socket_lib", + "//source/common/network:raw_buffer_socket_lib", "//source/common/profiler:profiler_lib", "//source/common/router:config_lib", "//source/common/upstream:host_utility_lib", diff --git a/source/server/http/admin.h b/source/server/http/admin.h index c329993faae0b..ddc351004bb54 100644 --- a/source/server/http/admin.h +++ b/source/server/http/admin.h @@ -21,6 +21,7 @@ #include "common/http/conn_manager_impl.h" #include "common/http/date_provider_impl.h" #include "common/http/utility.h" +#include "common/network/raw_buffer_socket.h" #include "server/config/network/http_connection_manager.h" @@ -174,7 +175,9 @@ class AdminImpl : public Admin, // Network::ListenerConfig Network::FilterChainFactory& filterChainFactory() override { return parent_; } Network::ListenSocket& socket() override { return parent_.mutable_socket(); } - Ssl::ServerContext* defaultSslContext() override { return nullptr; } + Network::TransportSocketFactory& transportSocketFactory() override { + return parent_.transport_socket_factory_; + } bool bindToPort() override { return true; } bool handOffRestoredDestinationConnections() const override { return false; } uint32_t perConnectionBufferLimitBytes() override { return 0; } @@ -192,6 +195,7 @@ class AdminImpl : public Admin, std::list access_logs_; const std::string profile_path_; Network::ListenSocketPtr socket_; + Network::RawBufferSocketFactory transport_socket_factory_; Http::ConnectionManagerStats stats_; Http::ConnectionManagerTracingStats tracing_stats_; NullRouteConfigProvider route_config_provider_; diff --git a/source/server/listener_manager_impl.cc b/source/server/listener_manager_impl.cc index 6a19315850457..b712bbdff2aaf 100644 --- a/source/server/listener_manager_impl.cc +++ b/source/server/listener_manager_impl.cc @@ -1,6 +1,7 @@ #include "server/listener_manager_impl.h" #include "envoy/registry/registry.h" +#include "envoy/server/transport_socket_config.h" #include "common/common/assert.h" #include "common/common/fmt.h" @@ -9,7 +10,6 @@ #include "common/network/listen_socket_impl.h" #include "common/network/utility.h" #include "common/protobuf/utility.h" -#include "common/ssl/context_config_impl.h" #include "server/configuration_impl.h" #include "server/drain_manager_impl.h" @@ -166,16 +166,41 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, ListenerManag "is currently not supported", address_->asString())); } - if (filter_chain.has_tls_context()) { - Ssl::ServerContextConfigImpl context_config(filter_chain.tls_context()); - tls_contexts_.emplace_back(parent_.server_.sslContextManager().createSslServerContext( - name_, sni_domains, *listener_scope_, context_config, skip_context_update)); - has_tls++; - if (filter_chain.tls_context().has_session_ticket_keys()) { - has_stk++; + + // If the cluster doesn't have transport socke configured, override with default transport + // socket implementation based on tls_context. We copy by value first then override if + // neccessary. + auto transport_socket = filter_chain.transport_socket(); + if (!filter_chain.has_transport_socket()) { + if (filter_chain.has_tls_context()) { + transport_socket.set_name(Config::TransportSocketNames::get().SSL); + MessageUtil::jsonConvert(filter_chain.tls_context(), *transport_socket.mutable_config()); + + has_tls++; + if (filter_chain.tls_context().has_session_ticket_keys()) { + has_stk++; + } + } else { + transport_socket.set_name(Config::TransportSocketNames::get().RAW_BUFFER); } } + + auto& config_factory = Config::Utility::getAndCheckFactory< + Server::Configuration::DownstreamTransportSocketConfigFactory>(transport_socket.name()); + ProtobufTypes::MessagePtr message = + Config::Utility::translateToFactoryConfig(transport_socket, config_factory); + + // Each transport socket factory owns one SslServerContext, we need to store them all in a + // vector since Ssl::ContextManager doesn't own SslServerContext. While transportSocketFacotry() + // always returns the first element of transport_socket_factories_, other transport socket + // factories are needed when the default Ssl::ServerContext updates SSL context based on + // ClientHello. This behavior is a workaround for initial SNI support before the full SNI based + // filter chain match is implemented. + transport_socket_factories_.emplace_back(config_factory.createTransportSocketFactory( + name_, sni_domains, skip_context_update, *message, *this)); + ASSERT(transport_socket_factories_.back() != nullptr); } + ASSERT(!transport_socket_factories_.empty()); // TODO(PiotrSikora): allow filter chains with mixed use of Session Ticket Keys. // This doesn't work right now, because BoringSSL uses "session context" (initial SSL_CTX that diff --git a/source/server/listener_manager_impl.h b/source/server/listener_manager_impl.h index 3a35e27ae5718..351d92f3bee44 100644 --- a/source/server/listener_manager_impl.h +++ b/source/server/listener_manager_impl.h @@ -4,6 +4,7 @@ #include "envoy/server/filter_config.h" #include "envoy/server/instance.h" #include "envoy/server/listener_manager.h" +#include "envoy/server/transport_socket_config.h" #include "envoy/server/worker.h" #include "common/common/logger.h" @@ -168,6 +169,7 @@ class ListenerImpl : public Network::ListenerConfig, public Configuration::FactoryContext, public Network::DrainDecision, public Network::FilterChainFactory, + public Configuration::TransportSocketFactoryContext, Logger::Loggable { public: /** @@ -214,8 +216,8 @@ class ListenerImpl : public Network::ListenerConfig, bool handOffRestoredDestinationConnections() const override { return hand_off_restored_destination_connections_; } - Ssl::ServerContext* defaultSslContext() override { - return tls_contexts_.empty() ? nullptr : tls_contexts_[0].get(); + Network::TransportSocketFactory& transportSocketFactory() override { + return *transport_socket_factories_[0]; } uint32_t perConnectionBufferLimitBytes() override { return per_connection_buffer_limit_bytes_; } Stats::Scope& listenerScope() override { return *listener_scope_; } @@ -252,6 +254,10 @@ class ListenerImpl : public Network::ListenerConfig, bool createNetworkFilterChain(Network::Connection& connection) override; bool createListenerFilterChain(Network::ListenerFilterManager& manager) override; + // Configuration::TransportSocketFactoryContext + Ssl::ContextManager& sslContextManager() override { return parent_.server_.sslContextManager(); } + Stats::Scope& statsScope() const override { return *listener_scope_; } + private: ListenerManagerImpl& parent_; Network::Address::InstanceConstSharedPtr address_; @@ -259,6 +265,7 @@ class ListenerImpl : public Network::ListenerConfig, Stats::ScopePtr global_scope_; // Stats with global named scope, but needed for LDS cleanup. Stats::ScopePtr listener_scope_; // Stats with listener named scope. std::vector tls_contexts_; + std::vector transport_socket_factories_; const bool bind_to_port_; const bool hand_off_restored_destination_connections_; const uint32_t per_connection_buffer_limit_bytes_; diff --git a/test/common/http/codec_client_test.cc b/test/common/http/codec_client_test.cc index e8d3ce484e13d..6d4e230daf207 100644 --- a/test/common/http/codec_client_test.cc +++ b/test/common/http/codec_client_test.cc @@ -190,8 +190,8 @@ class CodecNetworkTest : public testing::TestWithParam void { - Network::ConnectionPtr new_connection = - dispatcher_->createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); listener_callbacks_.onNewConnection(std::move(new_connection)); })); diff --git a/test/common/network/connection_impl_test.cc b/test/common/network/connection_impl_test.cc index 1061b5d1d836f..7e0a7e2e4f622 100644 --- a/test/common/network/connection_impl_test.cc +++ b/test/common/network/connection_impl_test.cc @@ -98,8 +98,8 @@ class ConnectionImplTest : public testing::TestWithParam { read_filter_.reset(new NiceMock()); EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher_->createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); listener_callbacks_.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) @@ -193,8 +193,8 @@ TEST_P(ConnectionImplTest, CloseDuringConnectCallback) { EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher_->createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); listener_callbacks_.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) @@ -257,8 +257,8 @@ TEST_P(ConnectionImplTest, SocketOptions) { EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { socket->setOptions(options); - Network::ConnectionPtr new_connection = - dispatcher_->createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); listener_callbacks_.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) @@ -305,8 +305,8 @@ TEST_P(ConnectionImplTest, SocketOptionsFailureTest) { EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { socket->setOptions(options); - Network::ConnectionPtr new_connection = - dispatcher_->createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); listener_callbacks_.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) @@ -369,8 +369,8 @@ TEST_P(ConnectionImplTest, ConnectionStats) { MockConnectionStats server_connection_stats; EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher_->createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); listener_callbacks_.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) @@ -884,8 +884,8 @@ class ReadBufferLimitTest : public ConnectionImplTest { read_filter_.reset(new NiceMock()); EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher_->createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); new_connection->setBufferLimits(read_buffer_limit); listener_callbacks_.onNewConnection(std::move(new_connection)); })); diff --git a/test/common/network/dns_impl_test.cc b/test/common/network/dns_impl_test.cc index 28accb4de76bc..f0d80680bb5ff 100644 --- a/test/common/network/dns_impl_test.cc +++ b/test/common/network/dns_impl_test.cc @@ -201,8 +201,8 @@ class TestDnsServer : public ListenerCallbacks { TestDnsServer(Event::DispatcherImpl& dispatcher) : dispatcher_(dispatcher) {} void onAccept(ConnectionSocketPtr&& socket, bool) override { - Network::ConnectionPtr new_connection = - dispatcher_.createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher_.createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); onNewConnection(std::move(new_connection)); } diff --git a/test/common/network/listener_impl_test.cc b/test/common/network/listener_impl_test.cc index 2887922e0b7ad..860947807c176 100644 --- a/test/common/network/listener_impl_test.cc +++ b/test/common/network/listener_impl_test.cc @@ -37,8 +37,8 @@ static void errorCallbackTest(Address::IpVersion version) { EXPECT_CALL(listener_callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher.createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); listener_callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(listener_callbacks, onNewConnection_(_)) @@ -104,8 +104,8 @@ TEST_P(ListenerImplTest, UseActualDst) { EXPECT_CALL(listener_callbacks2, onAccept_(_, _)).Times(0); EXPECT_CALL(listener_callbacks1, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher.createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); listener_callbacks1.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(listener_callbacks1, onNewConnection_(_)) @@ -139,8 +139,8 @@ TEST_P(ListenerImplTest, WildcardListenerUseActualDst) { EXPECT_CALL(listener_callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(socket), nullptr); + Network::ConnectionPtr new_connection = dispatcher.createServerConnection( + std::move(socket), Network::Test::createRawBufferSocket()); listener_callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(listener_callbacks, onNewConnection_(_)) diff --git a/test/common/network/proxy_protocol_test.cc b/test/common/network/proxy_protocol_test.cc index 4331d7f669b3a..691ac82d5239d 100644 --- a/test/common/network/proxy_protocol_test.cc +++ b/test/common/network/proxy_protocol_test.cc @@ -7,6 +7,7 @@ #include "common/filter/listener/proxy_protocol.h" #include "common/network/listen_socket_impl.h" #include "common/network/listener_impl.h" +#include "common/network/raw_buffer_socket.h" #include "common/network/utility.h" #include "common/stats/stats_impl.h" @@ -51,7 +52,9 @@ class ProxyProtocolTest : public testing::TestWithParam, // Listener Network::FilterChainFactory& filterChainFactory() override { return factory_; } Network::ListenSocket& socket() override { return socket_; } - Ssl::ServerContext* defaultSslContext() override { return nullptr; } + Network::TransportSocketFactory& transportSocketFactory() override { + return transport_socket_factory_; + } bool bindToPort() override { return true; } bool handOffRestoredDestinationConnections() const override { return false; } uint32_t perConnectionBufferLimitBytes() override { return 0; } @@ -123,6 +126,7 @@ class ProxyProtocolTest : public testing::TestWithParam, Event::DispatcherImpl dispatcher_; TcpListenSocket socket_; + Network::RawBufferSocketFactory transport_socket_factory_; Stats::IsolatedStoreImpl stats_store_; Network::ConnectionHandlerPtr connection_handler_; Network::MockFilterChainFactory factory_; @@ -338,7 +342,7 @@ class WildcardProxyProtocolTest : public testing::TestWithParamconnect(); Network::ConnectionPtr server_connection; Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(socket), server_ctx.get()); + Network::ConnectionPtr new_connection = dispatcher.createServerConnection( + std::move(socket), server_ssl_socket_factory.createTransportSocket()); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -136,17 +136,17 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, ContextManagerImpl manager(runtime); std::string new_session = EMPTY_STRING; - std::vector server_contexts; + std::vector server_transport_socket_factories; for (const auto& filter_chain : server_proto.filter_chains()) { if (filter_chain.has_tls_context()) { std::vector sni_domains(filter_chain.filter_chain_match().sni_domains().begin(), filter_chain.filter_chain_match().sni_domains().end()); Ssl::ServerContextConfigImpl server_ctx_config(filter_chain.tls_context()); - server_contexts.emplace_back(manager.createSslServerContext( - "test_listener", sni_domains, stats_store, server_ctx_config, false)); + server_transport_socket_factories.emplace_back(new Ssl::ServerSslSocketFactory( + server_ctx_config, "test_listener", sni_domains, false, manager, stats_store)); } } - ASSERT(server_contexts.size() >= 1); + ASSERT(server_transport_socket_factories.size() >= 1); Event::DispatcherImpl dispatcher; Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(version), true); @@ -155,11 +155,11 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, Network::ListenerPtr listener = dispatcher.createListener(socket, callbacks, true, false); ClientContextConfigImpl client_ctx_config(client_ctx_proto); - ClientSslSocketFactory ssl_socket_factory(client_ctx_config, manager, stats_store); + ClientSslSocketFactory client_ssl_socket_factory(client_ctx_config, manager, stats_store); ClientContextPtr client_ctx(manager.createSslClientContext(stats_store, client_ctx_config)); Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), - ssl_socket_factory.createTransportSocket(), nullptr); + client_ssl_socket_factory.createTransportSocket(), nullptr); if (!client_session.empty()) { Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); @@ -178,8 +178,8 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(socket), server_contexts[0].get()); + Network::ConnectionPtr new_connection = dispatcher.createServerConnection( + std::move(socket), server_transport_socket_factories[0]->createTransportSocket()); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -725,8 +725,8 @@ TEST_P(SslSocketTest, FlushCloseDuringHandshake) { Json::ObjectSharedPtr server_ctx_loader = TestEnvironment::jsonLoadFromString(server_ctx_json); ServerContextConfigImpl server_ctx_config(*server_ctx_loader); ContextManagerImpl manager(runtime); - ServerContextPtr server_ctx( - manager.createSslServerContext("", {}, stats_store, server_ctx_config, true)); + Ssl::ServerSslSocketFactory server_ssl_socket_factory(server_ctx_config, "", {}, true, manager, + stats_store); Event::DispatcherImpl dispatcher; Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), true); @@ -745,8 +745,8 @@ TEST_P(SslSocketTest, FlushCloseDuringHandshake) { Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(socket), server_ctx.get()); + Network::ConnectionPtr new_connection = dispatcher.createServerConnection( + std::move(socket), server_ssl_socket_factory.createTransportSocket()); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -781,8 +781,8 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { Json::ObjectSharedPtr server_ctx_loader = TestEnvironment::jsonLoadFromString(server_ctx_json); ServerContextConfigImpl server_ctx_config(*server_ctx_loader); ContextManagerImpl manager(runtime); - ServerContextPtr server_ctx( - manager.createSslServerContext("", {}, stats_store, server_ctx_config, true)); + Ssl::ServerSslSocketFactory server_ssl_socket_factory(server_ctx_config, "", {}, true, manager, + stats_store); Event::DispatcherImpl dispatcher; Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), true); @@ -821,8 +821,8 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(socket), server_ctx.get()); + Network::ConnectionPtr new_connection = dispatcher.createServerConnection( + std::move(socket), server_ssl_socket_factory.createTransportSocket()); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -859,10 +859,10 @@ void testTicketSessionResumption(const std::string& server_ctx_json1, Json::ObjectSharedPtr server_ctx_loader2 = TestEnvironment::jsonLoadFromString(server_ctx_json2); ServerContextConfigImpl server_ctx_config1(*server_ctx_loader1); ServerContextConfigImpl server_ctx_config2(*server_ctx_loader2); - ServerContextPtr server_ctx1( - manager.createSslServerContext("server1", {}, stats_store, server_ctx_config1, false)); - ServerContextPtr server_ctx2( - manager.createSslServerContext("server2", {}, stats_store, server_ctx_config2, false)); + Ssl::ServerSslSocketFactory server_ssl_socket_factory1(server_ctx_config1, "server1", {}, false, + manager, stats_store); + Ssl::ServerSslSocketFactory server_ssl_socket_factory2(server_ctx_config2, "server2", {}, false, + manager, stats_store); Event::DispatcherImpl dispatcher; Network::TcpListenSocket socket1(Network::Test::getCanonicalLoopbackAddress(ip_version), true); @@ -887,10 +887,11 @@ void testTicketSessionResumption(const std::string& server_ctx_json1, Network::ConnectionPtr server_connection; EXPECT_CALL(callbacks, onAccept_(_, _)) .WillRepeatedly(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - ServerContext* ctx = socket->localAddress() == socket1.localAddress() ? server_ctx1.get() - : server_ctx2.get(); + Network::TransportSocketFactory& tsf = socket->localAddress() == socket1.localAddress() + ? server_ssl_socket_factory1 + : server_ssl_socket_factory2; Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(socket), ctx); + dispatcher.createServerConnection(std::move(socket), tsf.createTransportSocket()); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -1180,10 +1181,10 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { Json::ObjectSharedPtr server2_ctx_loader = TestEnvironment::jsonLoadFromString(server2_ctx_json); ServerContextConfigImpl server2_ctx_config(*server2_ctx_loader); ContextManagerImpl manager(runtime); - ServerContextPtr server_ctx( - manager.createSslServerContext("server1", {}, stats_store, server_ctx_config, false)); - ServerContextPtr server2_ctx( - manager.createSslServerContext("server2", {}, stats_store, server2_ctx_config, false)); + Ssl::ServerSslSocketFactory server_ssl_socket_factory(server_ctx_config, "server1", {}, false, + manager, stats_store); + Ssl::ServerSslSocketFactory server2_ssl_socket_factory(server2_ctx_config, "server2", {}, false, + manager, stats_store); Event::DispatcherImpl dispatcher; Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), true); @@ -1216,11 +1217,12 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_, _)) .WillRepeatedly(Invoke([&](Network::ConnectionSocketPtr& accepted_socket, bool) -> void { - ServerContext* ctx = accepted_socket->localAddress() == socket.localAddress() - ? server_ctx.get() - : server2_ctx.get(); - Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(accepted_socket), ctx); + Network::TransportSocketFactory& tsf = + accepted_socket->localAddress() == socket.localAddress() ? server_ssl_socket_factory + : server2_ssl_socket_factory; + + Network::ConnectionPtr new_connection = dispatcher.createServerConnection( + std::move(accepted_socket), tsf.createTransportSocket()); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -1288,8 +1290,8 @@ TEST_P(SslSocketTest, SslError) { Json::ObjectSharedPtr server_ctx_loader = TestEnvironment::jsonLoadFromString(server_ctx_json); ServerContextConfigImpl server_ctx_config(*server_ctx_loader); ContextManagerImpl manager(runtime); - ServerContextPtr server_ctx( - manager.createSslServerContext("", {}, stats_store, server_ctx_config, true)); + Ssl::ServerSslSocketFactory server_ssl_socket_factory(server_ctx_config, "", {}, true, manager, + stats_store); Event::DispatcherImpl dispatcher; Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), true); @@ -1308,8 +1310,8 @@ TEST_P(SslSocketTest, SslError) { Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(socket), server_ctx.get()); + Network::ConnectionPtr new_connection = dispatcher.createServerConnection( + std::move(socket), server_ssl_socket_factory.createTransportSocket()); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -1927,7 +1929,8 @@ class SslReadBufferLimitTest : public SslCertsTest, server_ctx_loader_ = TestEnvironment::jsonLoadFromString(server_ctx_json_); server_ctx_config_.reset(new ServerContextConfigImpl(*server_ctx_loader_)); manager_.reset(new ContextManagerImpl(runtime_)); - server_ctx_ = manager_->createSslServerContext("", {}, stats_store_, *server_ctx_config_, true); + server_ssl_socket_factory_.reset( + new ServerSslSocketFactory(*server_ctx_config_, "", {}, true, *manager_, stats_store_)); listener_ = dispatcher_->createListener(socket_, listener_callbacks_, true, false); @@ -1950,8 +1953,8 @@ class SslReadBufferLimitTest : public SslCertsTest, EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher_->createServerConnection(std::move(socket), server_ctx_.get()); + Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( + std::move(socket), server_ssl_socket_factory_->createTransportSocket()); new_connection->setBufferLimits(read_buffer_limit); listener_callbacks_.onNewConnection(std::move(new_connection)); })); @@ -2033,8 +2036,8 @@ class SslReadBufferLimitTest : public SslCertsTest, EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher_->createServerConnection(std::move(socket), server_ctx_.get()); + Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( + std::move(socket), server_ssl_socket_factory_->createTransportSocket()); new_connection->setBufferLimits(read_buffer_limit); listener_callbacks_.onNewConnection(std::move(new_connection)); })); @@ -2100,7 +2103,7 @@ class SslReadBufferLimitTest : public SslCertsTest, Json::ObjectSharedPtr server_ctx_loader_; std::unique_ptr server_ctx_config_; std::unique_ptr manager_; - ServerContextPtr server_ctx_; + Network::TransportSocketFactoryPtr server_ssl_socket_factory_; Network::ListenerPtr listener_; Json::ObjectSharedPtr client_ctx_loader_; std::unique_ptr client_ctx_config_; @@ -2150,8 +2153,8 @@ TEST_P(SslReadBufferLimitTest, TestBind) { EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - Network::ConnectionPtr new_connection = - dispatcher_->createServerConnection(std::move(socket), server_ctx_.get()); + Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( + std::move(socket), server_ssl_socket_factory_->createTransportSocket()); new_connection->setBufferLimits(0); listener_callbacks_.onNewConnection(std::move(new_connection)); })); diff --git a/test/integration/fake_upstream.cc b/test/integration/fake_upstream.cc index e25e8d4ad0ead..77a9be32a3edb 100644 --- a/test/integration/fake_upstream.cc +++ b/test/integration/fake_upstream.cc @@ -14,7 +14,9 @@ #include "common/http/http2/codec_impl.h" #include "common/network/address_impl.h" #include "common/network/listen_socket_impl.h" +#include "common/network/raw_buffer_socket.h" #include "common/network/utility.h" +#include "common/ssl/ssl_socket.h" #include "server/connection_handler_impl.h" @@ -229,8 +231,8 @@ FakeStreamPtr FakeHttpConnection::waitForNewStream(Event::Dispatcher& client_dis } FakeUpstream::FakeUpstream(const std::string& uds_path, FakeHttpConnection::Type type) - : FakeUpstream(nullptr, Network::ListenSocketPtr{new Network::UdsListenSocket(uds_path)}, - type) { + : FakeUpstream(Network::Test::createRawBufferSocketFactory(), + Network::ListenSocketPtr{new Network::UdsListenSocket(uds_path)}, type) { ENVOY_LOG(info, "starting fake server on unix domain socket {}", uds_path); } @@ -244,22 +246,24 @@ static Network::ListenSocketPtr makeTcpListenSocket(uint32_t port, FakeUpstream::FakeUpstream(uint32_t port, FakeHttpConnection::Type type, Network::Address::IpVersion version) - : FakeUpstream(nullptr, makeTcpListenSocket(port, version), type) { + : FakeUpstream(Network::Test::createRawBufferSocketFactory(), + makeTcpListenSocket(port, version), type) { ENVOY_LOG(info, "starting fake server on port {}. Address version is {}", this->localAddress()->ip()->port(), Network::Test::addressVersionAsString(version)); } -FakeUpstream::FakeUpstream(Ssl::ServerContext* ssl_ctx, uint32_t port, - FakeHttpConnection::Type type, Network::Address::IpVersion version) - : FakeUpstream(ssl_ctx, makeTcpListenSocket(port, version), type) { +FakeUpstream::FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, + uint32_t port, FakeHttpConnection::Type type, + Network::Address::IpVersion version) + : FakeUpstream(std::move(transport_socket_factory), makeTcpListenSocket(port, version), type) { ENVOY_LOG(info, "starting fake SSL server on port {}. Address version is {}", this->localAddress()->ip()->port(), Network::Test::addressVersionAsString(version)); } -FakeUpstream::FakeUpstream(Ssl::ServerContext* ssl_ctx, Network::ListenSocketPtr&& listen_socket, - FakeHttpConnection::Type type) - : http_type_(type), ssl_ctx_(ssl_ctx), socket_(std::move(listen_socket)), - api_(new Api::Impl(std::chrono::milliseconds(10000))), +FakeUpstream::FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, + Network::ListenSocketPtr&& listen_socket, FakeHttpConnection::Type type) + : http_type_(type), transport_socket_factory_(std::move(transport_socket_factory)), + socket_(std::move(listen_socket)), api_(new Api::Impl(std::chrono::milliseconds(10000))), dispatcher_(api_->allocateDispatcher()), handler_(new Server::ConnectionHandlerImpl(ENVOY_LOGGER(), *dispatcher_)), allow_unexpected_disconnects_(false), listener_(*this) { diff --git a/test/integration/fake_upstream.h b/test/integration/fake_upstream.h index f7986604f5891..3ae9b5a80a869 100644 --- a/test/integration/fake_upstream.h +++ b/test/integration/fake_upstream.h @@ -282,8 +282,8 @@ class FakeUpstream : Logger::Loggable, public Network::Filt public: FakeUpstream(const std::string& uds_path, FakeHttpConnection::Type type); FakeUpstream(uint32_t port, FakeHttpConnection::Type type, Network::Address::IpVersion version); - FakeUpstream(Ssl::ServerContext* ssl_ctx, uint32_t port, FakeHttpConnection::Type type, - Network::Address::IpVersion version); + FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, uint32_t port, + FakeHttpConnection::Type type, Network::Address::IpVersion version); ~FakeUpstream(); FakeHttpConnection::Type httpType() { return http_type_; } @@ -307,8 +307,8 @@ class FakeUpstream : Logger::Loggable, public Network::Filt void cleanUp(); private: - FakeUpstream(Ssl::ServerContext* ssl_ctx, Network::ListenSocketPtr&& connection, - FakeHttpConnection::Type type); + FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, + Network::ListenSocketPtr&& connection, FakeHttpConnection::Type type); class FakeListener : public Network::ListenerConfig { public: @@ -318,7 +318,9 @@ class FakeUpstream : Logger::Loggable, public Network::Filt // Network::ListenerConfig Network::FilterChainFactory& filterChainFactory() override { return parent_; } Network::ListenSocket& socket() override { return *parent_.socket_; } - Ssl::ServerContext* defaultSslContext() override { return parent_.ssl_ctx_; } + Network::TransportSocketFactory& transportSocketFactory() override { + return *parent_.transport_socket_factory_; + } bool bindToPort() override { return true; } bool handOffRestoredDestinationConnections() const override { return false; } uint32_t perConnectionBufferLimitBytes() override { return 0; } @@ -332,7 +334,7 @@ class FakeUpstream : Logger::Loggable, public Network::Filt void threadRoutine(); - Ssl::ServerContext* ssl_ctx_{}; + Network::TransportSocketFactoryPtr transport_socket_factory_; Network::ListenSocketPtr socket_; ConditionalInitializer server_initialized_; // Guards any objects which can be altered both in the upstream thread and the diff --git a/test/integration/ssl_integration_test.cc b/test/integration/ssl_integration_test.cc index e4ffd3fdb1730..85cf15e961bb7 100644 --- a/test/integration/ssl_integration_test.cc +++ b/test/integration/ssl_integration_test.cc @@ -38,7 +38,6 @@ void SslIntegrationTest::initialize() { void SslIntegrationTest::TearDown() { test_server_.reset(); fake_upstreams_.clear(); - upstream_ssl_ctx_.reset(); client_ssl_ctx_plain_.reset(); client_ssl_ctx_alpn_.reset(); client_ssl_ctx_san_.reset(); @@ -47,20 +46,6 @@ void SslIntegrationTest::TearDown() { runtime_.reset(); } -ServerContextPtr SslIntegrationTest::createUpstreamSslContext() { - static auto* upstream_stats_store = new Stats::TestIsolatedStoreImpl(); - std::string json = R"EOF( -{ - "cert_chain_file": "{{ test_rundir }}/test/config/integration/certs/upstreamcert.pem", - "private_key_file": "{{ test_rundir }}/test/config/integration/certs/upstreamkey.pem" -} -)EOF"; - - Json::ObjectSharedPtr loader = TestEnvironment::jsonLoadFromString(json); - ServerContextConfigImpl cfg(*loader); - return context_manager_->createSslServerContext("", {}, *upstream_stats_store, cfg, true); -} - Network::ClientConnectionPtr SslIntegrationTest::makeSslClientConnection(bool alpn, bool san) { Network::Address::InstanceConstSharedPtr address = getSslAddress(version_, lookupPort("http")); if (alpn) { diff --git a/test/integration/ssl_integration_test.h b/test/integration/ssl_integration_test.h index aeabed964c466..755ef09dddaf5 100644 --- a/test/integration/ssl_integration_test.h +++ b/test/integration/ssl_integration_test.h @@ -26,13 +26,11 @@ class SslIntegrationTest : public HttpIntegrationTest, Network::ClientConnectionPtr makeSslConn() { return makeSslClientConnection(false, false); } Network::ClientConnectionPtr makeSslClientConnection(bool alpn, bool san); - ServerContextPtr createUpstreamSslContext(); void checkStats(); private: std::unique_ptr runtime_; std::unique_ptr context_manager_; - ServerContextPtr upstream_ssl_ctx_; Network::TransportSocketFactoryPtr client_ssl_ctx_plain_; Network::TransportSocketFactoryPtr client_ssl_ctx_alpn_; diff --git a/test/integration/xfcc_integration_test.cc b/test/integration/xfcc_integration_test.cc index 454c9a19530ac..8f954ea75e16f 100644 --- a/test/integration/xfcc_integration_test.cc +++ b/test/integration/xfcc_integration_test.cc @@ -30,7 +30,6 @@ void XfccIntegrationTest::TearDown() { client_mtls_ssl_ctx_.reset(); client_tls_ssl_ctx_.reset(); fake_upstreams_.clear(); - upstream_ssl_ctx_.reset(); context_manager_.reset(); runtime_.reset(); } @@ -64,7 +63,7 @@ Network::TransportSocketFactoryPtr XfccIntegrationTest::createClientSslContext(b new Ssl::ClientSslSocketFactory(cfg, *context_manager_, *client_stats_store)}; } -Ssl::ServerContextPtr XfccIntegrationTest::createUpstreamSslContext() { +Network::TransportSocketFactoryPtr XfccIntegrationTest::createUpstreamSslContext() { std::string json = R"EOF( { "cert_chain_file": "{{ test_rundir }}/test/config/integration/certs/upstreamcert.pem", @@ -74,8 +73,10 @@ Ssl::ServerContextPtr XfccIntegrationTest::createUpstreamSslContext() { Json::ObjectSharedPtr loader = TestEnvironment::jsonLoadFromString(json); Ssl::ServerContextConfigImpl cfg(*loader); - static auto* upstream_stats_store = new Stats::TestIsolatedStoreImpl(); - return context_manager_->createSslServerContext("", {}, *upstream_stats_store, cfg, true); + static Stats::Scope* upstream_stats_store = new Stats::TestIsolatedStoreImpl(); + return std::make_unique(cfg, EMPTY_STRING, + std::vector{}, true, + *context_manager_, *upstream_stats_store); } Network::ClientConnectionPtr XfccIntegrationTest::makeClientConnection() { @@ -96,9 +97,8 @@ Network::ClientConnectionPtr XfccIntegrationTest::makeMtlsClientConnection() { } void XfccIntegrationTest::createUpstreams() { - upstream_ssl_ctx_ = createUpstreamSslContext(); fake_upstreams_.emplace_back( - new FakeUpstream(upstream_ssl_ctx_.get(), 0, FakeHttpConnection::Type::HTTP1, version_)); + new FakeUpstream(createUpstreamSslContext(), 0, FakeHttpConnection::Type::HTTP1, version_)); } void XfccIntegrationTest::initialize() { diff --git a/test/integration/xfcc_integration_test.h b/test/integration/xfcc_integration_test.h index a10dd205afe79..f27ebaf7eac0d 100644 --- a/test/integration/xfcc_integration_test.h +++ b/test/integration/xfcc_integration_test.h @@ -36,7 +36,7 @@ class XfccIntegrationTest : public HttpIntegrationTest, void TearDown() override; - Ssl::ServerContextPtr createUpstreamSslContext(); + Network::TransportSocketFactoryPtr createUpstreamSslContext(); Network::TransportSocketFactoryPtr createClientSslContext(bool mtls); Network::ClientConnectionPtr makeClientConnection(); Network::ClientConnectionPtr makeTlsClientConnection(); @@ -51,7 +51,7 @@ class XfccIntegrationTest : public HttpIntegrationTest, std::unique_ptr context_manager_; Network::TransportSocketFactoryPtr client_tls_ssl_ctx_; Network::TransportSocketFactoryPtr client_mtls_ssl_ctx_; - Ssl::ServerContextPtr upstream_ssl_ctx_; + Network::TransportSocketFactoryPtr upstream_ssl_ctx_; }; } // namespace Xfcc } // namespace Envoy diff --git a/test/mocks/event/mocks.h b/test/mocks/event/mocks.h index eb309700674e7..64d363815f4e7 100644 --- a/test/mocks/event/mocks.h +++ b/test/mocks/event/mocks.h @@ -29,9 +29,10 @@ class MockDispatcher : public Dispatcher { MockDispatcher(); ~MockDispatcher(); - Network::ConnectionPtr createServerConnection(Network::ConnectionSocketPtr&& socket, - Ssl::Context* ssl_ctx) override { - return Network::ConnectionPtr{createServerConnection_(socket.get(), ssl_ctx)}; + Network::ConnectionPtr + createServerConnection(Network::ConnectionSocketPtr&& socket, + Network::TransportSocketPtr&& transport_socket) override { + return Network::ConnectionPtr{createServerConnection_(socket.get(), transport_socket.get())}; } Network::ClientConnectionPtr @@ -75,7 +76,8 @@ class MockDispatcher : public Dispatcher { // Event::Dispatcher MOCK_METHOD0(clearDeferredDeleteList, void()); MOCK_METHOD2(createServerConnection_, - Network::Connection*(Network::ConnectionSocket* socket, Ssl::Context* ssl_ctx)); + Network::Connection*(Network::ConnectionSocket* socket, + Network::TransportSocket* transport_socket)); MOCK_METHOD4( createClientConnection_, Network::ClientConnection*(Network::Address::InstanceConstSharedPtr address, diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index 1e618982894cd..c086222c3972b 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -297,7 +297,7 @@ class MockListenerConfig : public ListenerConfig { MOCK_METHOD0(filterChainFactory, FilterChainFactory&()); MOCK_METHOD0(socket, ListenSocket&()); - MOCK_METHOD0(defaultSslContext, Ssl::ServerContext*()); + MOCK_METHOD0(transportSocketFactory, TransportSocketFactory&()); MOCK_METHOD0(bindToPort, bool()); MOCK_CONST_METHOD0(handOffRestoredDestinationConnections, bool()); MOCK_METHOD0(perConnectionBufferLimitBytes, uint32_t()); diff --git a/test/server/BUILD b/test/server/BUILD index 1096fde1d094f..5d570330215af 100644 --- a/test/server/BUILD +++ b/test/server/BUILD @@ -129,6 +129,8 @@ envoy_cc_test( "//source/server:listener_manager_lib", "//source/server/config/listener:original_dst_lib", "//source/server/config/network:http_connection_manager_lib", + "//source/server/config/network:raw_buffer_socket_lib", + "//source/server/config/network:ssl_socket_lib", "//test/mocks/server:server_mocks", "//test/test_common:environment_lib", ], diff --git a/test/server/connection_handler_test.cc b/test/server/connection_handler_test.cc index 44710fc8c9bfe..19cb7ba90e82f 100644 --- a/test/server/connection_handler_test.cc +++ b/test/server/connection_handler_test.cc @@ -1,5 +1,6 @@ #include "common/common/utility.h" #include "common/network/address_impl.h" +#include "common/network/raw_buffer_socket.h" #include "common/stats/stats_impl.h" #include "server/connection_handler_impl.h" @@ -36,7 +37,9 @@ class ConnectionHandlerTest : public testing::Test, protected Logger::LoggableaddOrUpdateListener(parseListenerFromJson(json), true); - EXPECT_NE(nullptr, manager_->listeners().back().get().defaultSslContext()); + EXPECT_TRUE( + manager_->listeners().back().get().transportSocketFactory().implementsSecureTransport()); } TEST_F(ListenerManagerImplWithRealFiltersTest, BadListenerConfig) { diff --git a/test/test_common/network_utility.cc b/test/test_common/network_utility.cc index 4598d9e0d71d9..5c99cf00ceaeb 100644 --- a/test/test_common/network_utility.cc +++ b/test/test_common/network_utility.cc @@ -163,6 +163,9 @@ std::pair bindFreeLoopbackPort(Address::Ip TransportSocketPtr createRawBufferSocket() { return std::make_unique(); } +TransportSocketFactoryPtr createRawBufferSocketFactory() { + return std::make_unique(); +}; } // namespace Test } // namespace Network } // namespace Envoy diff --git a/test/test_common/network_utility.h b/test/test_common/network_utility.h index aa714a8fb13b6..d59355afc3374 100644 --- a/test/test_common/network_utility.h +++ b/test/test_common/network_utility.h @@ -97,8 +97,19 @@ bool supportsIpVersion(const Address::IpVersion version); std::pair bindFreeLoopbackPort(Address::IpVersion version, Address::SocketType type); +/** + * Create a transport socket for testing purposes. + * @return TransportSocketPtr the transport socket factory to use with a connection. + */ TransportSocketPtr createRawBufferSocket(); +/** + * Create a transport socket factory for testing purposes. + * @return TransportSocketFactoryPtr the transport socket factory to use with a cluster or a + * listener. + */ +TransportSocketFactoryPtr createRawBufferSocketFactory(); + } // namespace Test } // namespace Network } // namespace Envoy