diff --git a/envoy/network/filter.h b/envoy/network/filter.h index 9911ebb04ac64..ec6d14f9042a7 100644 --- a/envoy/network/filter.h +++ b/envoy/network/filter.h @@ -389,7 +389,7 @@ class FilterChain { * @return const TransportSocketFactory& a transport socket factory to be used by the new * connection. */ - virtual const TransportSocketFactory& transportSocketFactory() const PURE; + virtual const DownstreamTransportSocketFactory& transportSocketFactory() const PURE; /** * @return std::chrono::milliseconds the amount of time to wait for the transport socket to report diff --git a/envoy/network/transport_socket.h b/envoy/network/transport_socket.h index ca892ed079d53..a78f8f2cf2e2c 100644 --- a/envoy/network/transport_socket.h +++ b/envoy/network/transport_socket.h @@ -16,7 +16,6 @@ namespace Envoy { namespace Network { -class TransportSocketFactory; class Connection; enum class ConnectionEvent; @@ -239,20 +238,28 @@ class TransportSocketOptions { using TransportSocketOptionsConstSharedPtr = std::shared_ptr; /** - * A factory for creating transport socket. It will be associated to filter chains and clusters. - */ -class TransportSocketFactory { + * A factory for creating transport sockets. + **/ +class TransportSocketFactoryBase { public: - virtual ~TransportSocketFactory() = default; + virtual ~TransportSocketFactoryBase() = default; /** * @return bool whether the transport socket implements secure transport. */ virtual bool implementsSecureTransport() const PURE; +}; + +/** + * A factory for creating upstream transport sockets. It will be associated to clusters. + */ +class UpstreamTransportSocketFactory : public virtual TransportSocketFactoryBase { +public: + ~UpstreamTransportSocketFactory() override = default; /** * @param options for creating the transport socket - * @return Network::TransportSocketPtr a transport socket to be passed to connection. + * @return Network::TransportSocketPtr a transport socket to be passed to client connection. */ virtual TransportSocketPtr createTransportSocket(TransportSocketOptionsConstSharedPtr options) const PURE; @@ -280,7 +287,21 @@ class TransportSocketFactory { TransportSocketOptionsConstSharedPtr options) const PURE; }; -using TransportSocketFactoryPtr = std::unique_ptr; +/** + * A factory for creating downstream transport sockets. It will be associated to listeners. + */ +class DownstreamTransportSocketFactory : public virtual TransportSocketFactoryBase { +public: + ~DownstreamTransportSocketFactory() override = default; + + /** + * @return Network::TransportSocketPtr a transport socket to be passed to server connection. + */ + virtual TransportSocketPtr createDownstreamTransportSocket() const PURE; +}; + +using UpstreamTransportSocketFactoryPtr = std::unique_ptr; +using DownstreamTransportSocketFactoryPtr = std::unique_ptr; } // namespace Network } // namespace Envoy diff --git a/envoy/server/transport_socket_config.h b/envoy/server/transport_socket_config.h index 2a16bc4409712..9b3d5e3ad1c47 100644 --- a/envoy/server/transport_socket_config.h +++ b/envoy/server/transport_socket_config.h @@ -124,13 +124,13 @@ class UpstreamTransportSocketConfigFactory : public virtual TransportSocketConfi * @param config const Protobuf::Message& supplies the config message for the transport socket * implementation. * @param context TransportSocketFactoryContext& supplies the transport socket's context. - * @return Network::TransportSocketFactoryPtr the transport socket factory instance. The returned - * TransportSocketFactoryPtr should not be nullptr. + * @return Network::UpstreamTransportSocketFactoryPtr the transport socket factory instance. The + * returned TransportSocketFactoryPtr should not be nullptr. * * @throw EnvoyException if the implementation is unable to produce a factory with the provided * parameters. */ - virtual Network::TransportSocketFactoryPtr + virtual Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& config, TransportSocketFactoryContext& context) PURE; @@ -150,13 +150,13 @@ class DownstreamTransportSocketConfigFactory : public virtual TransportSocketCon * @param config const Protobuf::Message& supplies the config message for the transport socket * implementation. * @param context TransportSocketFactoryContext& supplies the transport socket's context. - * @return Network::TransportSocketFactoryPtr the transport socket factory instance. The returned - * TransportSocketFactoryPtr should not be nullptr. + * @return Network::DownstreamTransportSocketFactoryPtr the transport socket factory instance. The + * returned TransportSocketFactoryPtr should not be nullptr. * * @throw EnvoyException if the implementation is unable to produce a factory with the provided * parameters. */ - virtual Network::TransportSocketFactoryPtr + virtual Network::DownstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& config, TransportSocketFactoryContext& context, const std::vector& server_names) PURE; diff --git a/envoy/upstream/host_description.h b/envoy/upstream/host_description.h index 016de81280338..9231747331df8 100644 --- a/envoy/upstream/host_description.h +++ b/envoy/upstream/host_description.h @@ -148,7 +148,7 @@ class HostDescription { /** * @return the transport socket factory responsible for this host. */ - virtual Network::TransportSocketFactory& transportSocketFactory() const PURE; + virtual Network::UpstreamTransportSocketFactory& transportSocketFactory() const PURE; /** * @return the address used to connect to the host. @@ -220,10 +220,10 @@ struct TransportSocketMatchStats { class TransportSocketMatcher { public: struct MatchData { - MatchData(Network::TransportSocketFactory& factory, TransportSocketMatchStats& stats, + MatchData(Network::UpstreamTransportSocketFactory& factory, TransportSocketMatchStats& stats, std::string name) : factory_(factory), stats_(stats), name_(std::move(name)) {} - Network::TransportSocketFactory& factory_; + Network::UpstreamTransportSocketFactory& factory_; TransportSocketMatchStats& stats_; std::string name_; }; diff --git a/source/common/http/conn_pool_grid.cc b/source/common/http/conn_pool_grid.cc index 5a88fef59c7b6..6c45e48eca39d 100644 --- a/source/common/http/conn_pool_grid.cc +++ b/source/common/http/conn_pool_grid.cc @@ -19,7 +19,7 @@ absl::string_view describePool(const ConnectionPool::Instance& pool) { static constexpr uint32_t kDefaultTimeoutMs = 300; std::string getSni(const Network::TransportSocketOptionsConstSharedPtr& options, - Network::TransportSocketFactory& transport_socket_factory) { + Network::UpstreamTransportSocketFactory& transport_socket_factory) { if (options && options->serverNameOverride().has_value()) { return options->serverNameOverride().value(); } diff --git a/source/common/http/http3/conn_pool.cc b/source/common/http/http3/conn_pool.cc index 652514d4de9b0..5f78594c538af 100644 --- a/source/common/http/http3/conn_pool.cc +++ b/source/common/http/http3/conn_pool.cc @@ -23,7 +23,7 @@ uint32_t getMaxStreams(const Upstream::ClusterInfo& cluster) { } const Envoy::Ssl::ClientContextConfig& -getConfig(Network::TransportSocketFactory& transport_socket_factory) { +getConfig(Network::UpstreamTransportSocketFactory& transport_socket_factory) { return dynamic_cast(transport_socket_factory) .clientContextConfig(); } diff --git a/source/common/network/happy_eyeballs_connection_impl.cc b/source/common/network/happy_eyeballs_connection_impl.cc index c6dc14da14cbb..626d3c2478137 100644 --- a/source/common/network/happy_eyeballs_connection_impl.cc +++ b/source/common/network/happy_eyeballs_connection_impl.cc @@ -7,7 +7,7 @@ namespace Network { HappyEyeballsConnectionImpl::HappyEyeballsConnectionImpl( Event::Dispatcher& dispatcher, const std::vector& address_list, - Address::InstanceConstSharedPtr source_address, TransportSocketFactory& socket_factory, + Address::InstanceConstSharedPtr source_address, UpstreamTransportSocketFactory& socket_factory, TransportSocketOptionsConstSharedPtr transport_socket_options, const ConnectionSocket::OptionsSharedPtr options) : id_(ConnectionImpl::next_global_id_++), dispatcher_(dispatcher), diff --git a/source/common/network/happy_eyeballs_connection_impl.h b/source/common/network/happy_eyeballs_connection_impl.h index e314927f7222b..38085c51453d7 100644 --- a/source/common/network/happy_eyeballs_connection_impl.h +++ b/source/common/network/happy_eyeballs_connection_impl.h @@ -38,7 +38,7 @@ class HappyEyeballsConnectionImpl : public ClientConnection, HappyEyeballsConnectionImpl(Event::Dispatcher& dispatcher, const std::vector& address_list, Address::InstanceConstSharedPtr source_address, - TransportSocketFactory& socket_factory, + UpstreamTransportSocketFactory& socket_factory, TransportSocketOptionsConstSharedPtr transport_socket_options, const ConnectionSocket::OptionsSharedPtr options); @@ -197,7 +197,7 @@ class HappyEyeballsConnectionImpl : public ClientConnection, // State which is needed to construct a new connection. struct ConnectionConstructionState { Address::InstanceConstSharedPtr source_address_; - TransportSocketFactory& socket_factory_; + UpstreamTransportSocketFactory& socket_factory_; TransportSocketOptionsConstSharedPtr transport_socket_options_; const ConnectionSocket::OptionsSharedPtr options_; }; diff --git a/source/common/network/raw_buffer_socket.cc b/source/common/network/raw_buffer_socket.cc index a35dd49e70938..0d6dfa28c0c3c 100644 --- a/source/common/network/raw_buffer_socket.cc +++ b/source/common/network/raw_buffer_socket.cc @@ -91,6 +91,10 @@ RawBufferSocketFactory::createTransportSocket(TransportSocketOptionsConstSharedP return std::make_unique(); } +TransportSocketPtr RawBufferSocketFactory::createDownstreamTransportSocket() const { + return std::make_unique(); +} + bool RawBufferSocketFactory::implementsSecureTransport() const { return false; } } // namespace Network diff --git a/source/common/network/raw_buffer_socket.h b/source/common/network/raw_buffer_socket.h index c55bc06ff7337..fa7cc9ea2116b 100644 --- a/source/common/network/raw_buffer_socket.h +++ b/source/common/network/raw_buffer_socket.h @@ -33,13 +33,16 @@ class RawBufferSocket : public TransportSocket, protected Logger::Loggable& key, - TransportSocketOptionsConstSharedPtr options) const { +void CommonUpstreamTransportSocketFactory::hashKey( + std::vector& key, TransportSocketOptionsConstSharedPtr options) const { if (!options) { return; } diff --git a/source/common/network/transport_socket_options_impl.h b/source/common/network/transport_socket_options_impl.h index 34029dcbe7e46..6b94825dd522b 100644 --- a/source/common/network/transport_socket_options_impl.h +++ b/source/common/network/transport_socket_options_impl.h @@ -92,7 +92,7 @@ class TransportSocketOptionsUtility { fromFilterState(const StreamInfo::FilterStateSharedPtr& stream_info); }; -class CommonTransportSocketFactory : public TransportSocketFactory { +class CommonUpstreamTransportSocketFactory : public UpstreamTransportSocketFactory { public: /** * Compute the generic hash key from the transport socket options. diff --git a/source/common/quic/envoy_quic_crypto_stream_factory.h b/source/common/quic/envoy_quic_crypto_stream_factory.h index cbfcbfebe6781..24a36b5caebd0 100644 --- a/source/common/quic/envoy_quic_crypto_stream_factory.h +++ b/source/common/quic/envoy_quic_crypto_stream_factory.h @@ -22,7 +22,7 @@ class EnvoyQuicCryptoServerStreamFactoryInterface : public Config::TypedFactory const quic::QuicCryptoServerConfig* crypto_config, quic::QuicCompressedCertsCache* compressed_certs_cache, quic::QuicSession* session, quic::QuicCryptoServerStreamBase::Helper* helper, - OptRef transport_socket_factory, + OptRef transport_socket_factory, Event::Dispatcher& dispatcher) PURE; }; diff --git a/source/common/quic/quic_transport_socket_factory.cc b/source/common/quic/quic_transport_socket_factory.cc index ea38879ab0229..94428d6fa7796 100644 --- a/source/common/quic/quic_transport_socket_factory.cc +++ b/source/common/quic/quic_transport_socket_factory.cc @@ -12,7 +12,7 @@ namespace Envoy { namespace Quic { -Network::TransportSocketFactoryPtr +Network::DownstreamTransportSocketFactoryPtr QuicServerTransportSocketConfigFactory::createTransportSocketFactory( const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context, const std::vector& /*server_names*/) { @@ -33,7 +33,7 @@ ProtobufTypes::MessagePtr QuicServerTransportSocketConfigFactory::createEmptyCon envoy::extensions::transport_sockets::quic::v3::QuicDownstreamTransport>(); } -Network::TransportSocketFactoryPtr +Network::UpstreamTransportSocketFactoryPtr QuicClientTransportSocketConfigFactory::createTransportSocketFactory( const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context) { diff --git a/source/common/quic/quic_transport_socket_factory.h b/source/common/quic/quic_transport_socket_factory.h index 17ddde660b818..ef974dd108227 100644 --- a/source/common/quic/quic_transport_socket_factory.h +++ b/source/common/quic/quic_transport_socket_factory.h @@ -37,23 +37,16 @@ QuicTransportSocketFactoryStats generateStats(Stats::Scope& store, const std::st // socket for QUIC in current implementation. This factory doesn't provides a // transport socket, instead, its derived class provides TLS context config for // server and client. -class QuicTransportSocketFactoryBase : public Network::CommonTransportSocketFactory, - protected Logger::Loggable { +class QuicTransportSocketFactoryBase : protected Logger::Loggable { public: QuicTransportSocketFactoryBase(Stats::Scope& store, const std::string& perspective) : stats_(generateStats(store, perspective)) {} + virtual ~QuicTransportSocketFactoryBase() = default; + // To be called right after construction. virtual void initialize() PURE; - // Network::TransportSocketFactory - Network::TransportSocketPtr - createTransportSocket(Network::TransportSocketOptionsConstSharedPtr /*options*/) const override { - PANIC("not implemented"); - } - bool implementsSecureTransport() const override { return true; } - bool supportsAlpn() const override { return true; } - protected: virtual void onSecretUpdated() PURE; QuicTransportSocketFactoryStats stats_; @@ -61,13 +54,20 @@ class QuicTransportSocketFactoryBase : public Network::CommonTransportSocketFact // TODO(danzh): when implement ProofSource, examine of it's necessary to // differentiate server and client side context config. -class QuicServerTransportSocketFactory : public QuicTransportSocketFactoryBase { +class QuicServerTransportSocketFactory : public Network::DownstreamTransportSocketFactory, + public QuicTransportSocketFactoryBase { public: QuicServerTransportSocketFactory(bool enable_early_data, Stats::Scope& store, Ssl::ServerContextConfigPtr config) : QuicTransportSocketFactoryBase(store, "server"), config_(std::move(config)), enable_early_data_(enable_early_data) {} + // Network::DownstreamTransportSocketFactory + Network::TransportSocketPtr createDownstreamTransportSocket() const override { + PANIC("not implemented"); + } + bool implementsSecureTransport() const override { return true; } + void initialize() override { config_->setSecretUpdateCallback([this]() { // The callback also updates config_ with the new secret. @@ -87,7 +87,6 @@ class QuicServerTransportSocketFactory : public QuicTransportSocketFactoryBase { } bool earlyDataEnabled() const { return enable_early_data_; } - absl::string_view defaultServerNameIndication() const override { return ""; } protected: void onSecretUpdated() override { stats_.context_config_update_by_sds_.inc(); } @@ -97,13 +96,16 @@ class QuicServerTransportSocketFactory : public QuicTransportSocketFactoryBase { bool enable_early_data_; }; -class QuicClientTransportSocketFactory : public QuicTransportSocketFactoryBase { +class QuicClientTransportSocketFactory : public Network::CommonUpstreamTransportSocketFactory, + public QuicTransportSocketFactoryBase { public: QuicClientTransportSocketFactory( Ssl::ClientContextConfigPtr config, Server::Configuration::TransportSocketFactoryContext& factory_context); void initialize() override {} + bool implementsSecureTransport() const override { return true; } + bool supportsAlpn() const override { return true; } absl::string_view defaultServerNameIndication() const override { return clientContextConfig().serverNameIndication(); } @@ -157,7 +159,7 @@ class QuicServerTransportSocketConfigFactory public Server::Configuration::DownstreamTransportSocketConfigFactory { public: // Server::Configuration::DownstreamTransportSocketConfigFactory - Network::TransportSocketFactoryPtr + Network::DownstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context, const std::vector& server_names) override; @@ -173,7 +175,7 @@ class QuicClientTransportSocketConfigFactory public Server::Configuration::UpstreamTransportSocketConfigFactory { public: // Server::Configuration::UpstreamTransportSocketConfigFactory - Network::TransportSocketFactoryPtr createTransportSocketFactory( + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory( const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context) override; diff --git a/source/common/upstream/health_discovery_service.cc b/source/common/upstream/health_discovery_service.cc index 183d2669711ea..c83c35ef39acc 100644 --- a/source/common/upstream/health_discovery_service.cc +++ b/source/common/upstream/health_discovery_service.cc @@ -532,7 +532,7 @@ ProdClusterInfoFactory::createClusterInfo(const CreateClusterInfoParams& params) params.validation_visitor_, params.api_, params.options_, params.access_log_manager_); // TODO(JimmyCYJ): Support SDS for HDS cluster. - Network::TransportSocketFactoryPtr socket_factory = + Network::UpstreamTransportSocketFactoryPtr socket_factory = Upstream::createTransportSocketFactory(params.cluster_, factory_context); auto socket_matcher = std::make_unique( params.cluster_.transport_socket_matches(), factory_context, socket_factory, *scope); diff --git a/source/common/upstream/logical_host.h b/source/common/upstream/logical_host.h index 4620d06737d64..3cbebf9ab4210 100644 --- a/source/common/upstream/logical_host.h +++ b/source/common/upstream/logical_host.h @@ -90,7 +90,7 @@ class RealHostDescription : public HostDescription { MetadataConstSharedPtr metadata() const override { return logical_host_->metadata(); } void metadata(MetadataConstSharedPtr) override {} - Network::TransportSocketFactory& transportSocketFactory() const override { + Network::UpstreamTransportSocketFactory& transportSocketFactory() const override { return logical_host_->transportSocketFactory(); } const ClusterInfo& cluster() const override { return logical_host_->cluster(); } diff --git a/source/common/upstream/transport_socket_match_impl.cc b/source/common/upstream/transport_socket_match_impl.cc index 57e8f140a94b2..63df66e2c2c24 100644 --- a/source/common/upstream/transport_socket_match_impl.cc +++ b/source/common/upstream/transport_socket_match_impl.cc @@ -13,7 +13,7 @@ TransportSocketMatcherImpl::TransportSocketMatcherImpl( const Protobuf::RepeatedPtrField& socket_matches, Server::Configuration::TransportSocketFactoryContext& factory_context, - Network::TransportSocketFactoryPtr& default_factory, Stats::Scope& stats_scope) + Network::UpstreamTransportSocketFactoryPtr& default_factory, Stats::Scope& stats_scope) : stats_scope_(stats_scope), default_match_("default", std::move(default_factory), generateStats("default")) { for (const auto& socket_match : socket_matches) { diff --git a/source/common/upstream/transport_socket_match_impl.h b/source/common/upstream/transport_socket_match_impl.h index 6d9f665c8b827..ad1933e1638ad 100644 --- a/source/common/upstream/transport_socket_match_impl.h +++ b/source/common/upstream/transport_socket_match_impl.h @@ -23,11 +23,11 @@ class TransportSocketMatcherImpl : public Logger::Loggable public TransportSocketMatcher { public: struct FactoryMatch { - FactoryMatch(std::string match_name, Network::TransportSocketFactoryPtr socket_factory, + FactoryMatch(std::string match_name, Network::UpstreamTransportSocketFactoryPtr socket_factory, TransportSocketMatchStats match_stats) : name(std::move(match_name)), factory(std::move(socket_factory)), stats(match_stats) {} const std::string name; - Network::TransportSocketFactoryPtr factory; + Network::UpstreamTransportSocketFactoryPtr factory; Config::Metadata::LabelSet label_set; mutable TransportSocketMatchStats stats; }; @@ -36,7 +36,7 @@ class TransportSocketMatcherImpl : public Logger::Loggable const Protobuf::RepeatedPtrField& socket_matches, Server::Configuration::TransportSocketFactoryContext& factory_context, - Network::TransportSocketFactoryPtr& default_factory, Stats::Scope& stats_scope); + Network::UpstreamTransportSocketFactoryPtr& default_factory, Stats::Scope& stats_scope); MatchData resolve(const envoy::config::core::v3::Metadata* metadata) const override; diff --git a/source/common/upstream/upstream_impl.cc b/source/common/upstream/upstream_impl.cc index cedb77f32d8ec..9de76041481d9 100644 --- a/source/common/upstream/upstream_impl.cc +++ b/source/common/upstream/upstream_impl.cc @@ -265,7 +265,7 @@ HostDescriptionImpl::HostDescriptionImpl( : Network::Utility::getAddressWithPort(*dest_address, health_check_config.port_value()); } -Network::TransportSocketFactory& HostDescriptionImpl::resolveTransportSocketFactory( +Network::UpstreamTransportSocketFactory& HostDescriptionImpl::resolveTransportSocketFactory( const Network::Address::InstanceConstSharedPtr& dest_address, const envoy::config::core::v3::Metadata* metadata) const { auto match = cluster_->transportSocketMatcher().resolve(metadata); @@ -307,7 +307,7 @@ Host::CreateConnectionData HostImpl::createHealthCheckConnection( Network::TransportSocketOptionsConstSharedPtr transport_socket_options, const envoy::config::core::v3::Metadata* metadata) const { - Network::TransportSocketFactory& factory = + Network::UpstreamTransportSocketFactory& factory = (metadata != nullptr) ? resolveTransportSocketFactory(healthCheckAddress(), metadata) : transportSocketFactory(); return {createConnection(dispatcher, cluster(), healthCheckAddress(), {}, factory, nullptr, @@ -319,7 +319,7 @@ Network::ClientConnectionPtr HostImpl::createConnection( Event::Dispatcher& dispatcher, const ClusterInfo& cluster, const Network::Address::InstanceConstSharedPtr& address, const std::vector& address_list, - Network::TransportSocketFactory& socket_factory, + Network::UpstreamTransportSocketFactory& socket_factory, const Network::ConnectionSocket::OptionsSharedPtr& options, Network::TransportSocketOptionsConstSharedPtr transport_socket_options) { Network::ConnectionSocket::OptionsSharedPtr connection_options; @@ -1012,7 +1012,7 @@ ClusterInfoImpl::extensionProtocolOptions(const std::string& name) const { return nullptr; } -Network::TransportSocketFactoryPtr createTransportSocketFactory( +Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory( const envoy::config::cluster::v3::Cluster& config, Server::Configuration::TransportSocketFactoryContext& factory_context) { // If the cluster config doesn't have a transport socket configured, override with the default diff --git a/source/common/upstream/upstream_impl.h b/source/common/upstream/upstream_impl.h index 2dc5065a18f5b..555683b3ff719 100644 --- a/source/common/upstream/upstream_impl.h +++ b/source/common/upstream/upstream_impl.h @@ -98,7 +98,7 @@ class HostDescriptionImpl : virtual public HostDescription, const envoy::config::endpoint::v3::Endpoint::HealthCheckConfig& health_check_config, uint32_t priority, TimeSource& time_source); - Network::TransportSocketFactory& transportSocketFactory() const override { + Network::UpstreamTransportSocketFactory& transportSocketFactory() const override { absl::ReaderMutexLock lock(&metadata_mutex_); return socket_factory_; } @@ -170,7 +170,7 @@ class HostDescriptionImpl : virtual public HostDescription, } uint32_t priority() const override { return priority_; } void priority(uint32_t priority) override { priority_ = priority; } - Network::TransportSocketFactory& + Network::UpstreamTransportSocketFactory& resolveTransportSocketFactory(const Network::Address::InstanceConstSharedPtr& dest_address, const envoy::config::core::v3::Metadata* metadata) const; MonotonicTime creationTime() const override { return creation_time_; } @@ -211,7 +211,7 @@ class HostDescriptionImpl : virtual public HostDescription, Outlier::DetectorHostMonitorPtr outlier_detector_; HealthCheckHostMonitorPtr health_checker_; std::atomic priority_; - std::reference_wrapper + std::reference_wrapper socket_factory_ ABSL_GUARDED_BY(metadata_mutex_); const MonotonicTime creation_time_; }; @@ -293,7 +293,7 @@ class HostImpl : public HostDescriptionImpl, createConnection(Event::Dispatcher& dispatcher, const ClusterInfo& cluster, const Network::Address::InstanceConstSharedPtr& address, const std::vector& address_list, - Network::TransportSocketFactory& socket_factory, + Network::UpstreamTransportSocketFactory& socket_factory, const Network::ConnectionSocket::OptionsSharedPtr& options, Network::TransportSocketOptionsConstSharedPtr transport_socket_options); @@ -845,11 +845,11 @@ class ClusterInfoImpl : public ClusterInfo, }; /** - * Function that creates a Network::TransportSocketFactoryPtr + * Function that creates a Network::UpstreamTransportSocketFactoryPtr * given a cluster configuration and transport socket factory * context. */ -Network::TransportSocketFactoryPtr +Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory(const envoy::config::cluster::v3::Cluster& config, Server::Configuration::TransportSocketFactoryContext& factory_context); diff --git a/source/extensions/quic/crypto_stream/envoy_quic_crypto_server_stream.cc b/source/extensions/quic/crypto_stream/envoy_quic_crypto_server_stream.cc index ad17d8512605f..949bc5fafdd7a 100644 --- a/source/extensions/quic/crypto_stream/envoy_quic_crypto_server_stream.cc +++ b/source/extensions/quic/crypto_stream/envoy_quic_crypto_server_stream.cc @@ -10,7 +10,7 @@ EnvoyQuicCryptoServerStreamFactoryImpl::createEnvoyQuicCryptoServerStream( quic::QuicCryptoServerStreamBase::Helper* helper, // Though this extension doesn't use the two parameters below, they might be used by // downstreams. Do not remove them. - OptRef /*transport_socket_factory*/, + OptRef /*transport_socket_factory*/, Envoy::Event::Dispatcher& /*dispatcher*/) { return quic::CreateCryptoServerStream(crypto_config, compressed_certs_cache, session, helper); } diff --git a/source/extensions/quic/crypto_stream/envoy_quic_crypto_server_stream.h b/source/extensions/quic/crypto_stream/envoy_quic_crypto_server_stream.h index 10bf47cf5e8ba..20fbd9fd44f84 100644 --- a/source/extensions/quic/crypto_stream/envoy_quic_crypto_server_stream.h +++ b/source/extensions/quic/crypto_stream/envoy_quic_crypto_server_stream.h @@ -18,7 +18,7 @@ class EnvoyQuicCryptoServerStreamFactoryImpl : public EnvoyQuicCryptoServerStrea const quic::QuicCryptoServerConfig* crypto_config, quic::QuicCompressedCertsCache* compressed_certs_cache, quic::QuicSession* session, quic::QuicCryptoServerStreamBase::Helper* helper, - OptRef transport_socket_factory, + OptRef transport_socket_factory, Envoy::Event::Dispatcher& dispatcher) override; }; diff --git a/source/extensions/transport_sockets/alts/config.cc b/source/extensions/transport_sockets/alts/config.cc index 85d2ff354d4ab..7c3d18d27d7ab 100644 --- a/source/extensions/transport_sockets/alts/config.cc +++ b/source/extensions/transport_sockets/alts/config.cc @@ -91,7 +91,8 @@ class AltsSharedState : public Singleton::Instance { SINGLETON_MANAGER_REGISTRATION(alts_shared_state); -Network::TransportSocketFactoryPtr createTransportSocketFactoryHelper( +template +TransportSocketFactoryPtr createTransportSocketFactoryHelper( const Protobuf::Message& message, bool is_upstream, Server::Configuration::TransportSocketFactoryContext& factory_ctxt) { // A reference to this is held in the factory closure to keep the singleton @@ -147,19 +148,21 @@ ProtobufTypes::MessagePtr AltsTransportSocketConfigFactory::createEmptyConfigPro return std::make_unique(); } -Network::TransportSocketFactoryPtr +Network::UpstreamTransportSocketFactoryPtr UpstreamAltsTransportSocketConfigFactory::createTransportSocketFactory( const Protobuf::Message& message, Server::Configuration::TransportSocketFactoryContext& factory_ctxt) { - return createTransportSocketFactoryHelper(message, /* is_upstream */ true, factory_ctxt); + return createTransportSocketFactoryHelper( + message, /* is_upstream */ true, factory_ctxt); } -Network::TransportSocketFactoryPtr +Network::DownstreamTransportSocketFactoryPtr DownstreamAltsTransportSocketConfigFactory::createTransportSocketFactory( const Protobuf::Message& message, Server::Configuration::TransportSocketFactoryContext& factory_ctxt, const std::vector&) { - return createTransportSocketFactoryHelper(message, /* is_upstream */ false, factory_ctxt); + return createTransportSocketFactoryHelper( + message, /* is_upstream */ false, factory_ctxt); } REGISTER_FACTORY(UpstreamAltsTransportSocketConfigFactory, diff --git a/source/extensions/transport_sockets/alts/config.h b/source/extensions/transport_sockets/alts/config.h index 2a1a9d40b22d1..ce1df9853886d 100644 --- a/source/extensions/transport_sockets/alts/config.h +++ b/source/extensions/transport_sockets/alts/config.h @@ -19,7 +19,7 @@ class UpstreamAltsTransportSocketConfigFactory : public AltsTransportSocketConfigFactory, public Server::Configuration::UpstreamTransportSocketConfigFactory { public: - Network::TransportSocketFactoryPtr + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message&, Server::Configuration::TransportSocketFactoryContext&) override; }; @@ -28,7 +28,7 @@ class DownstreamAltsTransportSocketConfigFactory : public AltsTransportSocketConfigFactory, public Server::Configuration::DownstreamTransportSocketConfigFactory { public: - Network::TransportSocketFactoryPtr + Network::DownstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message&, Server::Configuration::TransportSocketFactoryContext&, const std::vector&) override; diff --git a/source/extensions/transport_sockets/alts/tsi_socket.cc b/source/extensions/transport_sockets/alts/tsi_socket.cc index 5689fd7f8d226..9dced4f35230d 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.cc +++ b/source/extensions/transport_sockets/alts/tsi_socket.cc @@ -385,6 +385,10 @@ TsiSocketFactory::createTransportSocket(Network::TransportSocketOptionsConstShar return std::make_unique(handshaker_factory_, handshake_validator_); } +Network::TransportSocketPtr TsiSocketFactory::createDownstreamTransportSocket() const { + return std::make_unique(handshaker_factory_, handshake_validator_); +} + } // namespace Alts } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h index 3dbf5f6c6472d..a95e087debd64 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.h +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -128,9 +128,10 @@ class TsiSocket : public Network::TransportSocket, }; /** - * An implementation of Network::TransportSocketFactory for TsiSocket + * An implementation of Network::UpstreamTransportSocketFactory for TsiSocket */ -class TsiSocketFactory : public Network::CommonTransportSocketFactory { +class TsiSocketFactory : public Network::DownstreamTransportSocketFactory, + public Network::CommonUpstreamTransportSocketFactory { public: TsiSocketFactory(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator); @@ -140,6 +141,8 @@ class TsiSocketFactory : public Network::CommonTransportSocketFactory { Network::TransportSocketPtr createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; + Network::TransportSocketPtr createDownstreamTransportSocket() const override; + private: HandshakerFactory handshaker_factory_; HandshakeValidator handshake_validator_; diff --git a/source/extensions/transport_sockets/common/passthrough.h b/source/extensions/transport_sockets/common/passthrough.h index dd56ab9aef2b3..669d8f642157a 100644 --- a/source/extensions/transport_sockets/common/passthrough.h +++ b/source/extensions/transport_sockets/common/passthrough.h @@ -10,9 +10,9 @@ namespace Envoy { namespace Extensions { namespace TransportSockets { -class PassthroughFactory : public Network::CommonTransportSocketFactory { +class PassthroughFactory : public Network::CommonUpstreamTransportSocketFactory { public: - PassthroughFactory(Network::TransportSocketFactoryPtr&& transport_socket_factory) + PassthroughFactory(Network::UpstreamTransportSocketFactoryPtr&& transport_socket_factory) : transport_socket_factory_(std::move(transport_socket_factory)) { ASSERT(transport_socket_factory_ != nullptr); } @@ -27,7 +27,24 @@ class PassthroughFactory : public Network::CommonTransportSocketFactory { protected: // The wrapped factory. - Network::TransportSocketFactoryPtr transport_socket_factory_; + Network::UpstreamTransportSocketFactoryPtr transport_socket_factory_; +}; + +class DownstreamPassthroughFactory : public Network::DownstreamTransportSocketFactory { +public: + DownstreamPassthroughFactory( + Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory) + : transport_socket_factory_(std::move(transport_socket_factory)) { + ASSERT(transport_socket_factory_ != nullptr); + } + + bool implementsSecureTransport() const override { + return transport_socket_factory_->implementsSecureTransport(); + } + +protected: + // The wrapped factory. + Network::DownstreamTransportSocketFactoryPtr transport_socket_factory_; }; class PassthroughSocket : public Network::TransportSocket { diff --git a/source/extensions/transport_sockets/proxy_protocol/config.cc b/source/extensions/transport_sockets/proxy_protocol/config.cc index 860fdb637c32a..9e62bc11fb1dd 100644 --- a/source/extensions/transport_sockets/proxy_protocol/config.cc +++ b/source/extensions/transport_sockets/proxy_protocol/config.cc @@ -12,7 +12,7 @@ namespace Extensions { namespace TransportSockets { namespace ProxyProtocol { -Network::TransportSocketFactoryPtr +Network::UpstreamTransportSocketFactoryPtr UpstreamProxyProtocolSocketConfigFactory::createTransportSocketFactory( const Protobuf::Message& message, Server::Configuration::TransportSocketFactoryContext& context) { diff --git a/source/extensions/transport_sockets/proxy_protocol/config.h b/source/extensions/transport_sockets/proxy_protocol/config.h index b5b4bbe7bdb9e..f40a938cc5303 100644 --- a/source/extensions/transport_sockets/proxy_protocol/config.h +++ b/source/extensions/transport_sockets/proxy_protocol/config.h @@ -16,7 +16,7 @@ class UpstreamProxyProtocolSocketConfigFactory public: std::string name() const override { return "envoy.transport_sockets.upstream_proxy_protocol"; } ProtobufTypes::MessagePtr createEmptyConfigProto() override; - Network::TransportSocketFactoryPtr createTransportSocketFactory( + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory( const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context) override; }; diff --git a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc index 7f3843a8f4aab..29cc86aa872ae 100644 --- a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc +++ b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc @@ -108,7 +108,7 @@ void UpstreamProxyProtocolSocket::onConnected() { } UpstreamProxyProtocolSocketFactory::UpstreamProxyProtocolSocketFactory( - Network::TransportSocketFactoryPtr transport_socket_factory, ProxyProtocolConfig config) + Network::UpstreamTransportSocketFactoryPtr transport_socket_factory, ProxyProtocolConfig config) : PassthroughFactory(std::move(transport_socket_factory)), config_(config) {} Network::TransportSocketPtr UpstreamProxyProtocolSocketFactory::createTransportSocket( diff --git a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h index eb35d93d10fdd..f91f11eb11014 100644 --- a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h +++ b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h @@ -41,10 +41,11 @@ class UpstreamProxyProtocolSocket : public TransportSockets::PassthroughSocket, class UpstreamProxyProtocolSocketFactory : public PassthroughFactory { public: - UpstreamProxyProtocolSocketFactory(Network::TransportSocketFactoryPtr transport_socket_factory, - ProxyProtocolConfig config); + UpstreamProxyProtocolSocketFactory( + Network::UpstreamTransportSocketFactoryPtr transport_socket_factory, + ProxyProtocolConfig config); - // Network::TransportSocketFactory + // Network::UpstreamTransportSocketFactory Network::TransportSocketPtr createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; void hashKey(std::vector& key, diff --git a/source/extensions/transport_sockets/raw_buffer/config.cc b/source/extensions/transport_sockets/raw_buffer/config.cc index 04dc896635301..6c2f06de713d4 100644 --- a/source/extensions/transport_sockets/raw_buffer/config.cc +++ b/source/extensions/transport_sockets/raw_buffer/config.cc @@ -12,12 +12,14 @@ namespace Extensions { namespace TransportSockets { namespace RawBuffer { -Network::TransportSocketFactoryPtr UpstreamRawBufferSocketFactory::createTransportSocketFactory( +Network::UpstreamTransportSocketFactoryPtr +UpstreamRawBufferSocketFactory::createTransportSocketFactory( const Protobuf::Message&, Server::Configuration::TransportSocketFactoryContext&) { return std::make_unique(); } -Network::TransportSocketFactoryPtr DownstreamRawBufferSocketFactory::createTransportSocketFactory( +Network::DownstreamTransportSocketFactoryPtr +DownstreamRawBufferSocketFactory::createTransportSocketFactory( const Protobuf::Message&, Server::Configuration::TransportSocketFactoryContext&, const std::vector&) { return std::make_unique(); diff --git a/source/extensions/transport_sockets/raw_buffer/config.h b/source/extensions/transport_sockets/raw_buffer/config.h index b1353b6c2d601..5dd53a124789f 100644 --- a/source/extensions/transport_sockets/raw_buffer/config.h +++ b/source/extensions/transport_sockets/raw_buffer/config.h @@ -22,7 +22,7 @@ class UpstreamRawBufferSocketFactory : public Server::Configuration::UpstreamTransportSocketConfigFactory, public RawBufferSocketFactory { public: - Network::TransportSocketFactoryPtr createTransportSocketFactory( + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory( const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context) override; }; @@ -31,7 +31,7 @@ class DownstreamRawBufferSocketFactory : public Server::Configuration::DownstreamTransportSocketConfigFactory, public RawBufferSocketFactory { public: - Network::TransportSocketFactoryPtr + Network::DownstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context, const std::vector& server_names) override; diff --git a/source/extensions/transport_sockets/starttls/config.cc b/source/extensions/transport_sockets/starttls/config.cc index f4911dd3b89d1..138319429e185 100644 --- a/source/extensions/transport_sockets/starttls/config.cc +++ b/source/extensions/transport_sockets/starttls/config.cc @@ -7,7 +7,8 @@ namespace Extensions { namespace TransportSockets { namespace StartTls { -Network::TransportSocketFactoryPtr DownstreamStartTlsSocketFactory::createTransportSocketFactory( +Network::DownstreamTransportSocketFactoryPtr +DownstreamStartTlsSocketFactory::createTransportSocketFactory( const Protobuf::Message& message, Server::Configuration::TransportSocketFactoryContext& context, const std::vector& server_names) { const auto& outer_config = MessageUtil::downcastAndValidate< @@ -17,19 +18,20 @@ Network::TransportSocketFactoryPtr DownstreamStartTlsSocketFactory::createTransp auto& raw_socket_config_factory = rawSocketConfigFactory(); auto& tls_socket_config_factory = tlsSocketConfigFactory(); - Network::TransportSocketFactoryPtr raw_socket_factory = + Network::DownstreamTransportSocketFactoryPtr raw_socket_factory = raw_socket_config_factory.createTransportSocketFactory(outer_config.cleartext_socket_config(), context, server_names); - Network::TransportSocketFactoryPtr tls_socket_factory = + Network::DownstreamTransportSocketFactoryPtr tls_socket_factory = tls_socket_config_factory.createTransportSocketFactory(outer_config.tls_socket_config(), context, server_names); - return std::make_unique(std::move(raw_socket_factory), - std::move(tls_socket_factory)); + return std::make_unique(std::move(raw_socket_factory), + std::move(tls_socket_factory)); } -Network::TransportSocketFactoryPtr UpstreamStartTlsSocketFactory::createTransportSocketFactory( +Network::UpstreamTransportSocketFactoryPtr +UpstreamStartTlsSocketFactory::createTransportSocketFactory( const Protobuf::Message& message, Server::Configuration::TransportSocketFactoryContext& context) { @@ -39,11 +41,11 @@ Network::TransportSocketFactoryPtr UpstreamStartTlsSocketFactory::createTranspor auto& raw_socket_config_factory = rawSocketConfigFactory(); auto& tls_socket_config_factory = tlsSocketConfigFactory(); - Network::TransportSocketFactoryPtr raw_socket_factory = + Network::UpstreamTransportSocketFactoryPtr raw_socket_factory = raw_socket_config_factory.createTransportSocketFactory(outer_config.cleartext_socket_config(), context); - Network::TransportSocketFactoryPtr tls_socket_factory = + Network::UpstreamTransportSocketFactoryPtr tls_socket_factory = tls_socket_config_factory.createTransportSocketFactory(outer_config.tls_socket_config(), context); diff --git a/source/extensions/transport_sockets/starttls/config.h b/source/extensions/transport_sockets/starttls/config.h index d60c9dbf0a50c..30de231dc8e94 100644 --- a/source/extensions/transport_sockets/starttls/config.h +++ b/source/extensions/transport_sockets/starttls/config.h @@ -36,7 +36,7 @@ class DownstreamStartTlsSocketFactory Server::Configuration::DownstreamTransportSocketConfigFactory, envoy::extensions::transport_sockets::starttls::v3::StartTlsConfig> { public: - Network::TransportSocketFactoryPtr + Network::DownstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context, const std::vector& server_names) override; @@ -47,7 +47,7 @@ class UpstreamStartTlsSocketFactory Server::Configuration::UpstreamTransportSocketConfigFactory, envoy::extensions::transport_sockets::starttls::v3::UpstreamStartTlsConfig> { public: - Network::TransportSocketFactoryPtr createTransportSocketFactory( + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory( const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context) override; }; diff --git a/source/extensions/transport_sockets/starttls/starttls_socket.cc b/source/extensions/transport_sockets/starttls/starttls_socket.cc index 93cd2e362ffd4..f7e26c5cfb119 100644 --- a/source/extensions/transport_sockets/starttls/starttls_socket.cc +++ b/source/extensions/transport_sockets/starttls/starttls_socket.cc @@ -30,6 +30,13 @@ Network::TransportSocketPtr StartTlsSocketFactory::createTransportSocket( transport_socket_options); } +Network::TransportSocketPtr +StartTlsDownstreamSocketFactory::createDownstreamTransportSocket() const { + return std::make_unique(raw_socket_factory_->createDownstreamTransportSocket(), + tls_socket_factory_->createDownstreamTransportSocket(), + nullptr); +} + } // namespace StartTls } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/starttls/starttls_socket.h b/source/extensions/transport_sockets/starttls/starttls_socket.h index ad3c14b1ff319..1bb21cc2fd1ca 100644 --- a/source/extensions/transport_sockets/starttls/starttls_socket.h +++ b/source/extensions/transport_sockets/starttls/starttls_socket.h @@ -105,13 +105,13 @@ class StartTlsSocket : public Network::TransportSocket, Logger::Loggable { public: ~StartTlsSocketFactory() override = default; - StartTlsSocketFactory(Network::TransportSocketFactoryPtr raw_socket_factory, - Network::TransportSocketFactoryPtr tls_socket_factory) + StartTlsSocketFactory(Network::UpstreamTransportSocketFactoryPtr raw_socket_factory, + Network::UpstreamTransportSocketFactoryPtr tls_socket_factory) : raw_socket_factory_(std::move(raw_socket_factory)), tls_socket_factory_(std::move(tls_socket_factory)) {} @@ -121,8 +121,26 @@ class StartTlsSocketFactory : public Network::CommonTransportSocketFactory, absl::string_view defaultServerNameIndication() const override { return ""; } private: - Network::TransportSocketFactoryPtr raw_socket_factory_; - Network::TransportSocketFactoryPtr tls_socket_factory_; + Network::UpstreamTransportSocketFactoryPtr raw_socket_factory_; + Network::UpstreamTransportSocketFactoryPtr tls_socket_factory_; +}; + +class StartTlsDownstreamSocketFactory : public Network::DownstreamTransportSocketFactory, + Logger::Loggable { +public: + ~StartTlsDownstreamSocketFactory() override = default; + + StartTlsDownstreamSocketFactory(Network::DownstreamTransportSocketFactoryPtr raw_socket_factory, + Network::DownstreamTransportSocketFactoryPtr tls_socket_factory) + : raw_socket_factory_(std::move(raw_socket_factory)), + tls_socket_factory_(std::move(tls_socket_factory)) {} + + Network::TransportSocketPtr createDownstreamTransportSocket() const override; + bool implementsSecureTransport() const override { return false; } + +private: + Network::DownstreamTransportSocketFactoryPtr raw_socket_factory_; + Network::DownstreamTransportSocketFactoryPtr tls_socket_factory_; }; } // namespace StartTls diff --git a/source/extensions/transport_sockets/tap/config.cc b/source/extensions/transport_sockets/tap/config.cc index 26255d00151d4..b84b173d2d36e 100644 --- a/source/extensions/transport_sockets/tap/config.cc +++ b/source/extensions/transport_sockets/tap/config.cc @@ -31,7 +31,8 @@ class SocketTapConfigFactoryImpl : public Extensions::Common::Tap::TapConfigFact TimeSource& time_source_; }; -Network::TransportSocketFactoryPtr UpstreamTapSocketConfigFactory::createTransportSocketFactory( +Network::UpstreamTransportSocketFactoryPtr +UpstreamTapSocketConfigFactory::createTransportSocketFactory( const Protobuf::Message& message, Server::Configuration::TransportSocketFactoryContext& context) { const auto& outer_config = @@ -50,7 +51,8 @@ Network::TransportSocketFactoryPtr UpstreamTapSocketConfigFactory::createTranspo context.mainThreadDispatcher(), std::move(inner_transport_factory)); } -Network::TransportSocketFactoryPtr DownstreamTapSocketConfigFactory::createTransportSocketFactory( +Network::DownstreamTransportSocketFactoryPtr +DownstreamTapSocketConfigFactory::createTransportSocketFactory( const Protobuf::Message& message, Server::Configuration::TransportSocketFactoryContext& context, const std::vector& server_names) { const auto& outer_config = @@ -63,7 +65,7 @@ Network::TransportSocketFactoryPtr DownstreamTapSocketConfigFactory::createTrans outer_config.transport_socket(), context.messageValidationVisitor(), inner_config_factory); auto inner_transport_factory = inner_config_factory.createTransportSocketFactory( *inner_factory_config, context, server_names); - return std::make_unique( + return std::make_unique( outer_config, std::make_unique(context.mainThreadDispatcher().timeSource()), context.admin(), context.singletonManager(), context.threadLocal(), diff --git a/source/extensions/transport_sockets/tap/config.h b/source/extensions/transport_sockets/tap/config.h index c641cf45484ba..6bed5b5dd23dc 100644 --- a/source/extensions/transport_sockets/tap/config.h +++ b/source/extensions/transport_sockets/tap/config.h @@ -21,7 +21,7 @@ class UpstreamTapSocketConfigFactory : public Server::Configuration::UpstreamTransportSocketConfigFactory, public TapSocketConfigFactory { public: - Network::TransportSocketFactoryPtr createTransportSocketFactory( + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory( const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context) override; }; @@ -30,7 +30,7 @@ class DownstreamTapSocketConfigFactory : public Server::Configuration::DownstreamTransportSocketConfigFactory, public TapSocketConfigFactory { public: - Network::TransportSocketFactoryPtr + Network::DownstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context, const std::vector& server_names) override; diff --git a/source/extensions/transport_sockets/tap/tap.cc b/source/extensions/transport_sockets/tap/tap.cc index 982c55747c697..9ba0622512e78 100644 --- a/source/extensions/transport_sockets/tap/tap.cc +++ b/source/extensions/transport_sockets/tap/tap.cc @@ -51,7 +51,7 @@ TapSocketFactory::TapSocketFactory( Common::Tap::TapConfigFactoryPtr&& config_factory, Server::Admin& admin, Singleton::Manager& singleton_manager, ThreadLocal::SlotAllocator& tls, Event::Dispatcher& main_thread_dispatcher, - Network::TransportSocketFactoryPtr&& transport_socket_factory) + Network::UpstreamTransportSocketFactoryPtr&& transport_socket_factory) : ExtensionConfigBase(proto_config.common_config(), std::move(config_factory), admin, singleton_manager, tls, main_thread_dispatcher), PassthroughFactory(std::move(transport_socket_factory)) {} @@ -62,6 +62,21 @@ Network::TransportSocketPtr TapSocketFactory::createTransportSocket( transport_socket_factory_->createTransportSocket(options)); } +DownstreamTapSocketFactory::DownstreamTapSocketFactory( + const envoy::extensions::transport_sockets::tap::v3::Tap& proto_config, + Common::Tap::TapConfigFactoryPtr&& config_factory, Server::Admin& admin, + Singleton::Manager& singleton_manager, ThreadLocal::SlotAllocator& tls, + Event::Dispatcher& main_thread_dispatcher, + Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory) + : ExtensionConfigBase(proto_config.common_config(), std::move(config_factory), admin, + singleton_manager, tls, main_thread_dispatcher), + DownstreamPassthroughFactory(std::move(transport_socket_factory)) {} + +Network::TransportSocketPtr DownstreamTapSocketFactory::createDownstreamTransportSocket() const { + return std::make_unique(currentConfigHelper(), + transport_socket_factory_->createDownstreamTransportSocket()); +} + } // namespace Tap } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/tap/tap.h b/source/extensions/transport_sockets/tap/tap.h index a0e9b4459982c..0167a029a274b 100644 --- a/source/extensions/transport_sockets/tap/tap.h +++ b/source/extensions/transport_sockets/tap/tap.h @@ -34,13 +34,27 @@ class TapSocketFactory : public Common::Tap::ExtensionConfigBase, public Passthr Common::Tap::TapConfigFactoryPtr&& config_factory, Server::Admin& admin, Singleton::Manager& singleton_manager, ThreadLocal::SlotAllocator& tls, Event::Dispatcher& main_thread_dispatcher, - Network::TransportSocketFactoryPtr&& transport_socket_factory); + Network::UpstreamTransportSocketFactoryPtr&& transport_socket_factory); - // Network::TransportSocketFactory + // Network::UpstreamTransportSocketFactory Network::TransportSocketPtr createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; }; +class DownstreamTapSocketFactory : public Common::Tap::ExtensionConfigBase, + public DownstreamPassthroughFactory { +public: + DownstreamTapSocketFactory( + const envoy::extensions::transport_sockets::tap::v3::Tap& proto_config, + Common::Tap::TapConfigFactoryPtr&& config_factory, Server::Admin& admin, + Singleton::Manager& singleton_manager, ThreadLocal::SlotAllocator& tls, + Event::Dispatcher& main_thread_dispatcher, + Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory); + + // Network::UpstreamTransportSocketFactory + Network::TransportSocketPtr createDownstreamTransportSocket() const override; +}; + } // namespace Tap } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/tcp_stats/config.cc b/source/extensions/transport_sockets/tcp_stats/config.cc index 417d30891d9a1..7b428c51de40b 100644 --- a/source/extensions/transport_sockets/tcp_stats/config.cc +++ b/source/extensions/transport_sockets/tcp_stats/config.cc @@ -14,9 +14,7 @@ namespace TcpStats { TcpStatsSocketFactory::TcpStatsSocketFactory( Server::Configuration::TransportSocketFactoryContext& context, - const envoy::extensions::transport_sockets::tcp_stats::v3::Config& config, - Network::TransportSocketFactoryPtr&& inner_factory) - : PassthroughFactory(std::move(inner_factory)) { + const envoy::extensions::transport_sockets::tcp_stats::v3::Config& config) { #if defined(__linux__) config_ = std::make_shared(config, context.scope()); #else @@ -26,7 +24,13 @@ TcpStatsSocketFactory::TcpStatsSocketFactory( #endif } -Network::TransportSocketPtr TcpStatsSocketFactory::createTransportSocket( +UpstreamTcpStatsSocketFactory::UpstreamTcpStatsSocketFactory( + Server::Configuration::TransportSocketFactoryContext& context, + const envoy::extensions::transport_sockets::tcp_stats::v3::Config& config, + Network::UpstreamTransportSocketFactoryPtr&& inner_factory) + : TcpStatsSocketFactory(context, config), PassthroughFactory(std::move(inner_factory)) {} + +Network::TransportSocketPtr UpstreamTcpStatsSocketFactory::createTransportSocket( Network::TransportSocketOptionsConstSharedPtr options) const { #if defined(__linux__) auto inner_socket = transport_socket_factory_->createTransportSocket(options); @@ -40,6 +44,26 @@ Network::TransportSocketPtr TcpStatsSocketFactory::createTransportSocket( #endif } +DownstreamTcpStatsSocketFactory::DownstreamTcpStatsSocketFactory( + Server::Configuration::TransportSocketFactoryContext& context, + const envoy::extensions::transport_sockets::tcp_stats::v3::Config& config, + Network::DownstreamTransportSocketFactoryPtr&& inner_factory) + : TcpStatsSocketFactory(context, config), + DownstreamPassthroughFactory(std::move(inner_factory)) {} + +Network::TransportSocketPtr +DownstreamTcpStatsSocketFactory::createDownstreamTransportSocket() const { +#if defined(__linux__) + auto inner_socket = transport_socket_factory_->createDownstreamTransportSocket(); + if (inner_socket == nullptr) { + return nullptr; + } + return std::make_unique(config_, std::move(inner_socket)); +#else + return nullptr; +#endif +} + class TcpStatsConfigFactory : public virtual Server::Configuration::TransportSocketConfigFactory { public: std::string name() const override { return "envoy.transport_sockets.tcp_stats"; } @@ -52,7 +76,7 @@ class UpstreamTcpStatsConfigFactory : public Server::Configuration::UpstreamTransportSocketConfigFactory, public TcpStatsConfigFactory { public: - Network::TransportSocketFactoryPtr createTransportSocketFactory( + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory( const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context) override { const auto& outer_config = MessageUtil::downcastAndValidate< @@ -67,8 +91,8 @@ class UpstreamTcpStatsConfigFactory inner_config_factory); auto inner_transport_factory = inner_config_factory.createTransportSocketFactory(*inner_factory_config, context); - return std::make_unique(context, outer_config, - std::move(inner_transport_factory)); + return std::make_unique(context, outer_config, + std::move(inner_transport_factory)); } }; @@ -76,7 +100,7 @@ class DownstreamTcpStatsConfigFactory : public Server::Configuration::DownstreamTransportSocketConfigFactory, public TcpStatsConfigFactory { public: - Network::TransportSocketFactoryPtr + Network::DownstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context, const std::vector& server_names) override { @@ -92,8 +116,8 @@ class DownstreamTcpStatsConfigFactory inner_config_factory); auto inner_transport_factory = inner_config_factory.createTransportSocketFactory( *inner_factory_config, context, server_names); - return std::make_unique(context, outer_config, - std::move(inner_transport_factory)); + return std::make_unique(context, outer_config, + std::move(inner_transport_factory)); } }; diff --git a/source/extensions/transport_sockets/tcp_stats/config.h b/source/extensions/transport_sockets/tcp_stats/config.h index 5d008d4b9bf46..d40a7268ea862 100644 --- a/source/extensions/transport_sockets/tcp_stats/config.h +++ b/source/extensions/transport_sockets/tcp_stats/config.h @@ -11,21 +11,39 @@ namespace Extensions { namespace TransportSockets { namespace TcpStats { -class TcpStatsSocketFactory : public PassthroughFactory { +class TcpStatsSocketFactory { public: TcpStatsSocketFactory(Server::Configuration::TransportSocketFactoryContext& context, - const envoy::extensions::transport_sockets::tcp_stats::v3::Config& config, - Network::TransportSocketFactoryPtr&& inner_factory); + const envoy::extensions::transport_sockets::tcp_stats::v3::Config& config); - Network::TransportSocketPtr - createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; - -private: +protected: #if defined(__linux__) ConfigConstSharedPtr config_; #endif }; +class UpstreamTcpStatsSocketFactory : public TcpStatsSocketFactory, public PassthroughFactory { +public: + UpstreamTcpStatsSocketFactory( + Server::Configuration::TransportSocketFactoryContext& context, + const envoy::extensions::transport_sockets::tcp_stats::v3::Config& config, + Network::UpstreamTransportSocketFactoryPtr&& inner_factory); + + Network::TransportSocketPtr + createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; +}; + +class DownstreamTcpStatsSocketFactory : public TcpStatsSocketFactory, + public DownstreamPassthroughFactory { +public: + DownstreamTcpStatsSocketFactory( + Server::Configuration::TransportSocketFactoryContext& context, + const envoy::extensions::transport_sockets::tcp_stats::v3::Config& config, + Network::DownstreamTransportSocketFactoryPtr&& inner_factory); + + Network::TransportSocketPtr createDownstreamTransportSocket() const override; +}; + } // namespace TcpStats } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/tls/config.cc b/source/extensions/transport_sockets/tls/config.cc index 4f4d46526c1f4..a7362d7813141 100644 --- a/source/extensions/transport_sockets/tls/config.cc +++ b/source/extensions/transport_sockets/tls/config.cc @@ -12,7 +12,7 @@ namespace Extensions { namespace TransportSockets { namespace Tls { -Network::TransportSocketFactoryPtr UpstreamSslSocketFactory::createTransportSocketFactory( +Network::UpstreamTransportSocketFactoryPtr UpstreamSslSocketFactory::createTransportSocketFactory( const Protobuf::Message& message, Server::Configuration::TransportSocketFactoryContext& context) { auto client_config = std::make_unique( @@ -31,7 +31,8 @@ ProtobufTypes::MessagePtr UpstreamSslSocketFactory::createEmptyConfigProto() { REGISTER_FACTORY(UpstreamSslSocketFactory, Server::Configuration::UpstreamTransportSocketConfigFactory){"tls"}; -Network::TransportSocketFactoryPtr DownstreamSslSocketFactory::createTransportSocketFactory( +Network::DownstreamTransportSocketFactoryPtr +DownstreamSslSocketFactory::createTransportSocketFactory( const Protobuf::Message& message, Server::Configuration::TransportSocketFactoryContext& context, const std::vector& server_names) { auto server_config = std::make_unique( diff --git a/source/extensions/transport_sockets/tls/config.h b/source/extensions/transport_sockets/tls/config.h index f05066bd8141a..b9aa1f86ba9e5 100644 --- a/source/extensions/transport_sockets/tls/config.h +++ b/source/extensions/transport_sockets/tls/config.h @@ -21,7 +21,7 @@ class SslSocketConfigFactory : public virtual Server::Configuration::TransportSo class UpstreamSslSocketFactory : public Server::Configuration::UpstreamTransportSocketConfigFactory, public SslSocketConfigFactory { public: - Network::TransportSocketFactoryPtr createTransportSocketFactory( + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory( const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context) override; ProtobufTypes::MessagePtr createEmptyConfigProto() override; @@ -33,7 +33,7 @@ class DownstreamSslSocketFactory : public Server::Configuration::DownstreamTransportSocketConfigFactory, public SslSocketConfigFactory { public: - Network::TransportSocketFactoryPtr + Network::DownstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context, const std::vector& server_names) override; diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index 9ea64c3fc429a..7552b0e92d721 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -419,8 +419,7 @@ Envoy::Ssl::ClientContextSharedPtr ClientSslSocketFactory::sslCtx() { return ssl_ctx_; } -Network::TransportSocketPtr ServerSslSocketFactory::createTransportSocket( - Network::TransportSocketOptionsConstSharedPtr transport_socket_options) const { +Network::TransportSocketPtr ServerSslSocketFactory::createDownstreamTransportSocket() const { // onAddOrUpdateSecret() could be invoked in the middle of checking the existence of ssl_ctx and // creating SslSocket using ssl_ctx. Capture ssl_ctx_ into a local variable so that we check and // use the same ssl_ctx to create SslSocket. @@ -430,8 +429,8 @@ Network::TransportSocketPtr ServerSslSocketFactory::createTransportSocket( ssl_ctx = ssl_ctx_; } if (ssl_ctx) { - return std::make_unique(std::move(ssl_ctx), InitialState::Server, - transport_socket_options, config_->createHandshaker()); + return std::make_unique(std::move(ssl_ctx), InitialState::Server, nullptr, + config_->createHandshaker()); } else { ENVOY_LOG(debug, "Create NotReadySslSocket"); stats_.downstream_context_secrets_not_ready_.inc(); diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index 894cacea040a0..6458f7453c884 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -98,7 +98,7 @@ class SslSocket : public Network::TransportSocket, SslHandshakerImplSharedPtr info_; }; -class ClientSslSocketFactory : public Network::CommonTransportSocketFactory, +class ClientSslSocketFactory : public Network::CommonUpstreamTransportSocketFactory, public Secret::SecretCallbacks, Logger::Loggable { public: @@ -131,7 +131,7 @@ class ClientSslSocketFactory : public Network::CommonTransportSocketFactory, Envoy::Ssl::ClientContextSharedPtr ssl_ctx_ ABSL_GUARDED_BY(ssl_ctx_mu_); }; -class ServerSslSocketFactory : public Network::CommonTransportSocketFactory, +class ServerSslSocketFactory : public Network::DownstreamTransportSocketFactory, public Secret::SecretCallbacks, Logger::Loggable { public: @@ -141,10 +141,8 @@ class ServerSslSocketFactory : public Network::CommonTransportSocketFactory, ~ServerSslSocketFactory() override; - Network::TransportSocketPtr - createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; + Network::TransportSocketPtr createDownstreamTransportSocket() const override; bool implementsSecureTransport() const override; - absl::string_view defaultServerNameIndication() const override { return ""; } // Secret::SecretCallbacks void onAddOrUpdateSecret() override; diff --git a/source/server/active_stream_listener_base.cc b/source/server/active_stream_listener_base.cc index abc5ca7ce4fb7..c6aa7d1261c3d 100644 --- a/source/server/active_stream_listener_base.cc +++ b/source/server/active_stream_listener_base.cc @@ -40,7 +40,7 @@ void ActiveStreamListenerBase::newConnection(Network::ConnectionSocketPtr&& sock return; } stream_info->setFilterChainName(filter_chain->name()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto server_conn_ptr = dispatcher().createServerConnection( std::move(socket), std::move(transport_socket), *stream_info); if (const auto timeout = filter_chain->transportSocketConnectTimeout(); diff --git a/source/server/admin/admin.h b/source/server/admin/admin.h index 151d10ba6eee9..8fbde9502c6e1 100644 --- a/source/server/admin/admin.h +++ b/source/server/admin/admin.h @@ -439,7 +439,7 @@ class AdminImpl : public Admin, AdminFilterChain() {} // NOLINT(modernize-use-equals-default) // Network::FilterChain - const Network::TransportSocketFactory& transportSocketFactory() const override { + const Network::DownstreamTransportSocketFactory& transportSocketFactory() const override { return transport_socket_factory_; } diff --git a/source/server/filter_chain_manager_impl.h b/source/server/filter_chain_manager_impl.h index e3322602c35f1..0a57ff5054c4f 100644 --- a/source/server/filter_chain_manager_impl.h +++ b/source/server/filter_chain_manager_impl.h @@ -102,7 +102,7 @@ class PerFilterChainFactoryContextImpl : public Configuration::FilterChainFactor class FilterChainImpl : public Network::DrainableFilterChain { public: - FilterChainImpl(Network::TransportSocketFactoryPtr&& transport_socket_factory, + FilterChainImpl(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory, std::vector&& filters_factory, std::chrono::milliseconds transport_socket_connect_timeout, absl::string_view name) @@ -111,7 +111,7 @@ class FilterChainImpl : public Network::DrainableFilterChain { transport_socket_connect_timeout_(transport_socket_connect_timeout), name_(name) {} // Network::FilterChain - const Network::TransportSocketFactory& transportSocketFactory() const override { + const Network::DownstreamTransportSocketFactory& transportSocketFactory() const override { return *transport_socket_factory_; } std::chrono::milliseconds transportSocketConnectTimeout() const override { @@ -132,7 +132,7 @@ class FilterChainImpl : public Network::DrainableFilterChain { private: Configuration::FilterChainFactoryContextPtr factory_context_; - const Network::TransportSocketFactoryPtr transport_socket_factory_; + const Network::DownstreamTransportSocketFactoryPtr transport_socket_factory_; const std::vector filters_factory_; const std::chrono::milliseconds transport_socket_connect_timeout_; const std::string name_; diff --git a/test/common/grpc/grpc_client_integration_test_harness.h b/test/common/grpc/grpc_client_integration_test_harness.h index a0947cf4c2a80..8336afff08011 100644 --- a/test/common/grpc/grpc_client_integration_test_harness.h +++ b/test/common/grpc/grpc_client_integration_test_harness.h @@ -538,7 +538,7 @@ class GrpcSslClientIntegrationTest : public GrpcClientIntegrationTest { GrpcClientIntegrationTest::initialize(); } - Network::TransportSocketFactoryPtr createUpstreamSslContext() { + Network::DownstreamTransportSocketFactoryPtr createUpstreamSslContext() { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; auto* common_tls_context = tls_context.mutable_common_tls_context(); common_tls_context->add_alpn_protocols(Http::Utility::AlpnNames::get().Http2); diff --git a/test/common/network/raw_buffer_socket_test.cc b/test/common/network/raw_buffer_socket_test.cc index 46456985483fd..3f285053eeb72 100644 --- a/test/common/network/raw_buffer_socket_test.cc +++ b/test/common/network/raw_buffer_socket_test.cc @@ -9,7 +9,7 @@ namespace Envoy { namespace Network { TEST(RawBufferSocketFactory, RawBufferSocketFactory) { - TransportSocketFactoryPtr factory = Envoy::Network::Test::createRawBufferSocketFactory(); + UpstreamTransportSocketFactoryPtr factory = Envoy::Network::Test::createRawBufferSocketFactory(); EXPECT_FALSE(factory->implementsSecureTransport()); std::vector keys; factory->hashKey(keys, nullptr); diff --git a/test/common/quic/envoy_quic_server_session_test.cc b/test/common/quic/envoy_quic_server_session_test.cc index d5ca034caa8c0..f33cf847058e7 100644 --- a/test/common/quic/envoy_quic_server_session_test.cc +++ b/test/common/quic/envoy_quic_server_session_test.cc @@ -123,7 +123,7 @@ class EnvoyQuicTestCryptoServerStreamFactory : public EnvoyQuicCryptoServerStrea const quic::QuicCryptoServerConfig* crypto_config, quic::QuicCompressedCertsCache* compressed_certs_cache, quic::QuicSession* session, quic::QuicCryptoServerStreamBase::Helper* helper, - OptRef /*transport_socket_factory*/, + OptRef /*transport_socket_factory*/, Event::Dispatcher& /*dispatcher*/) override { switch (session->connection()->version().handshake_protocol) { case quic::PROTOCOL_QUIC_CRYPTO: diff --git a/test/common/quic/quic_transport_socket_factory_test.cc b/test/common/quic/quic_transport_socket_factory_test.cc index cfa47fd6b84f5..38af57aacf46c 100644 --- a/test/common/quic/quic_transport_socket_factory_test.cc +++ b/test/common/quic/quic_transport_socket_factory_test.cc @@ -23,7 +23,7 @@ class QuicServerTransportSocketFactoryConfigTest : public Event::TestUsingSimula void verifyQuicServerTransportSocketFactory(std::string yaml, bool expect_early_data) { envoy::extensions::transport_sockets::quic::v3::QuicDownstreamTransport proto_config; TestUtility::loadFromYaml(yaml, proto_config); - Network::TransportSocketFactoryPtr transport_socket_factory = + Network::DownstreamTransportSocketFactoryPtr transport_socket_factory = config_factory_.createTransportSocketFactory(proto_config, context_, {}); EXPECT_EQ(expect_early_data, static_cast(*transport_socket_factory) diff --git a/test/common/upstream/cluster_manager_impl_test.cc b/test/common/upstream/cluster_manager_impl_test.cc index 97c72002a244c..b93308f38ead9 100644 --- a/test/common/upstream/cluster_manager_impl_test.cc +++ b/test/common/upstream/cluster_manager_impl_test.cc @@ -466,7 +466,7 @@ class AlpnTestConfigFactory : public Envoy::Extensions::TransportSockets::RawBuffer::UpstreamRawBufferSocketFactory { public: std::string name() const override { return "envoy.transport_sockets.alpn"; } - Network::TransportSocketFactoryPtr + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message&, Server::Configuration::TransportSocketFactoryContext&) override { return std::make_unique(); diff --git a/test/common/upstream/hds_test.cc b/test/common/upstream/hds_test.cc index 84efa07d3848e..2643131ab9144 100644 --- a/test/common/upstream/hds_test.cc +++ b/test/common/upstream/hds_test.cc @@ -589,7 +589,7 @@ TEST_F(HdsTest, TestSocketContext) { params.validation_visitor_, params.api_, params.options_, params.access_log_manager_); // Create a mock socket_factory for the scope of this unit test. - std::unique_ptr socket_factory = + std::unique_ptr socket_factory = std::make_unique(); // set socket_matcher object in test scope. @@ -1041,7 +1041,7 @@ TEST_F(HdsTest, TestUpdateSocketContext) { params.validation_visitor_, params.api_, params.options_, params.access_log_manager_); // Create a mock socket_factory for the scope of this unit test. - std::unique_ptr socket_factory = + std::unique_ptr socket_factory = std::make_unique(); // set socket_matcher object in test scope. diff --git a/test/common/upstream/health_checker_impl_test.cc b/test/common/upstream/health_checker_impl_test.cc index 0aedf94a6e5e7..ab2288823a0d5 100644 --- a/test/common/upstream/health_checker_impl_test.cc +++ b/test/common/upstream/health_checker_impl_test.cc @@ -1100,7 +1100,7 @@ TEST_F(HttpHealthCheckerImplTest, TlsOptions) { auto socket_factory = new Network::MockTransportSocketFactory(); EXPECT_CALL(*socket_factory, implementsSecureTransport()).WillRepeatedly(Return(true)); auto transport_socket_match = new NiceMock( - Network::TransportSocketFactoryPtr(socket_factory)); + Network::UpstreamTransportSocketFactoryPtr(socket_factory)); cluster_->info_->transport_socket_matcher_.reset(transport_socket_match); EXPECT_CALL(*socket_factory, createTransportSocket(ApplicationProtocolListEq("http1"))); diff --git a/test/common/upstream/transport_socket_matcher_test.cc b/test/common/upstream/transport_socket_matcher_test.cc index 857898f8b0562..63ae5efbe1f68 100644 --- a/test/common/upstream/transport_socket_matcher_test.cc +++ b/test/common/upstream/transport_socket_matcher_test.cc @@ -28,7 +28,7 @@ namespace Envoy { namespace Upstream { namespace { -class FakeTransportSocketFactory : public Network::TransportSocketFactory { +class FakeTransportSocketFactory : public Network::UpstreamTransportSocketFactory { public: MOCK_METHOD(bool, implementsSecureTransport, (), (const)); MOCK_METHOD(bool, supportsAlpn, (), (const)); @@ -49,7 +49,7 @@ class FakeTransportSocketFactory : public Network::TransportSocketFactory { }; class FooTransportSocketFactory - : public Network::TransportSocketFactory, + : public Network::UpstreamTransportSocketFactory, public Server::Configuration::UpstreamTransportSocketConfigFactory, Logger::Loggable { public: @@ -60,7 +60,7 @@ class FooTransportSocketFactory (const)); MOCK_METHOD(absl::string_view, defaultServerNameIndication, (), (const)); - Network::TransportSocketFactoryPtr + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& proto, Server::Configuration::TransportSocketFactoryContext&) override { const auto& node = dynamic_cast(proto); @@ -109,7 +109,7 @@ class TransportSocketMatcherTest : public testing::Test { TransportSocketMatcherPtr matcher_; NiceMock mock_factory_context_; - Network::TransportSocketFactoryPtr mock_default_factory_; + Network::UpstreamTransportSocketFactoryPtr mock_default_factory_; Stats::IsolatedStoreImpl stats_store_; Stats::ScopeSharedPtr stats_scope_; }; diff --git a/test/extensions/bootstrap/internal_listener/active_internal_listener_test.cc b/test/extensions/bootstrap/internal_listener/active_internal_listener_test.cc index 223f6acc6d797..c8034bcb797ee 100644 --- a/test/extensions/bootstrap/internal_listener/active_internal_listener_test.cc +++ b/test/extensions/bootstrap/internal_listener/active_internal_listener_test.cc @@ -165,7 +165,7 @@ TEST_F(ActiveInternalListenerTest, AcceptSocketAndCreateNetworkFilter) { EXPECT_CALL(*test_listener_filter, destroy_()); auto filter_factory_callback = std::make_shared>(); filter_chain_ = std::make_shared>(); - auto transport_socket_factory = Network::Test::createRawBufferSocketFactory(); + auto transport_socket_factory = Network::Test::createRawBufferDownstreamSocketFactory(); EXPECT_CALL(manager_, findFilterChain(_)).WillOnce(Return(filter_chain_.get())); EXPECT_CALL(*filter_chain_, transportSocketFactory) @@ -211,7 +211,7 @@ TEST_F(ActiveInternalListenerTest, DestroyListenerCloseAllConnections) { auto filter_factory_callback = std::make_shared>(); filter_chain_ = std::make_shared>(); - auto transport_socket_factory = Network::Test::createRawBufferSocketFactory(); + auto transport_socket_factory = Network::Test::createRawBufferDownstreamSocketFactory(); EXPECT_CALL(filter_chain_factory_, createListenerFilterChain(_)) .WillRepeatedly(Invoke([&](Network::ListenerFilterManager&) -> bool { return true; })); @@ -238,8 +238,8 @@ class ConnectionHandlerTest : public testing::Test, protected Logger::Loggable>()), access_log_(std::make_shared()) { ON_CALL(*filter_chain_, transportSocketFactory) - .WillByDefault(ReturnPointee(std::shared_ptr{ - Network::Test::createRawBufferSocketFactory()})); + .WillByDefault(ReturnPointee(std::shared_ptr{ + Network::Test::createRawBufferDownstreamSocketFactory()})); ON_CALL(*filter_chain_, networkFilterFactories) .WillByDefault(ReturnPointee(std::make_shared>())); ON_CALL(*listener_filter_matcher_, matches(_)).WillByDefault(Return(false)); diff --git a/test/extensions/filters/http/alternate_protocols_cache/filter_integration_test.cc b/test/extensions/filters/http/alternate_protocols_cache/filter_integration_test.cc index 41c96bd60d57d..b8910b290f3a0 100644 --- a/test/extensions/filters/http/alternate_protocols_cache/filter_integration_test.cc +++ b/test/extensions/filters/http/alternate_protocols_cache/filter_integration_test.cc @@ -80,12 +80,14 @@ name: alternate_protocols_cache TRY_ASSERT_MAIN_THREAD { // Make the first upstream HTTP/2 auto http2_config = configWithType(Http::CodecType::HTTP2); - Network::TransportSocketFactoryPtr http2_factory = createUpstreamTlsContext(http2_config); + Network::DownstreamTransportSocketFactoryPtr http2_factory = + createUpstreamTlsContext(http2_config); addFakeUpstream(std::move(http2_factory), Http::CodecType::HTTP2); // Make the next upstream is HTTP/3 auto http3_config = configWithType(Http::CodecType::HTTP3); - Network::TransportSocketFactoryPtr http3_factory = createUpstreamTlsContext(http3_config); + Network::DownstreamTransportSocketFactoryPtr http3_factory = + createUpstreamTlsContext(http3_config); // If the UDP port is in use, this will throw an exception and get caught below. fake_upstreams_.emplace_back(std::make_unique( std::move(http3_factory), fake_upstreams_[0]->localAddress()->ip()->port(), version_, @@ -288,11 +290,11 @@ class MixedUpstreamIntegrationTest : public FilterIntegrationTest { if (use_http2_) { auto config = configWithType(Http::CodecType::HTTP2); - Network::TransportSocketFactoryPtr factory = createUpstreamTlsContext(config); + Network::DownstreamTransportSocketFactoryPtr factory = createUpstreamTlsContext(config); addFakeUpstream(std::move(factory), Http::CodecType::HTTP2); } else { auto config = configWithType(Http::CodecType::HTTP3); - Network::TransportSocketFactoryPtr factory = createUpstreamTlsContext(config); + Network::DownstreamTransportSocketFactoryPtr factory = createUpstreamTlsContext(config); addFakeUpstream(std::move(factory), Http::CodecType::HTTP3); writeFile(); } diff --git a/test/extensions/filters/http/dynamic_forward_proxy/proxy_filter_test.cc b/test/extensions/filters/http/dynamic_forward_proxy/proxy_filter_test.cc index c41bbdbdacab6..2d101b65ee6fa 100644 --- a/test/extensions/filters/http/dynamic_forward_proxy/proxy_filter_test.cc +++ b/test/extensions/filters/http/dynamic_forward_proxy/proxy_filter_test.cc @@ -41,7 +41,7 @@ class ProxyFilterTest : public testing::Test, void setupSocketMatcher() { cm_.initializeThreadLocalClusters({"fake_cluster"}); transport_socket_match_ = new NiceMock( - Network::TransportSocketFactoryPtr(transport_socket_factory_)); + Network::UpstreamTransportSocketFactoryPtr(transport_socket_factory_)); cm_.thread_local_cluster_.cluster_.info_->transport_socket_matcher_.reset( transport_socket_match_); } diff --git a/test/extensions/filters/http/ratelimit/ratelimit_integration_test.cc b/test/extensions/filters/http/ratelimit/ratelimit_integration_test.cc index b9f538506f41e..6c9290bcbe767 100644 --- a/test/extensions/filters/http/ratelimit/ratelimit_integration_test.cc +++ b/test/extensions/filters/http/ratelimit/ratelimit_integration_test.cc @@ -38,7 +38,7 @@ class RatelimitIntegrationTest : public Grpc::GrpcClientIntegrationParamTest, // Add autonomous upstream. auto endpoint = upstream_address_fn_(0); fake_upstreams_.emplace_back(new AutonomousUpstream( - Network::Test::createRawBufferSocketFactory(), endpoint->ip()->port(), + Network::Test::createRawBufferDownstreamSocketFactory(), endpoint->ip()->port(), endpoint->ip()->version(), upstreamConfig(), true)); // Add ratelimit upstream. diff --git a/test/extensions/filters/http/router/auto_sni_integration_test.cc b/test/extensions/filters/http/router/auto_sni_integration_test.cc index 4c8307125d1a7..4c3584858f5db 100644 --- a/test/extensions/filters/http/router/auto_sni_integration_test.cc +++ b/test/extensions/filters/http/router/auto_sni_integration_test.cc @@ -49,7 +49,7 @@ class AutoSniIntegrationTest : public testing::TestWithParamadd_tls_certificates(); diff --git a/test/extensions/filters/listener/tls_inspector/tls_inspector_integration_test.cc b/test/extensions/filters/listener/tls_inspector/tls_inspector_integration_test.cc index 3631dc7132130..33cc06c17b5dc 100644 --- a/test/extensions/filters/listener/tls_inspector/tls_inspector_integration_test.cc +++ b/test/extensions/filters/listener/tls_inspector/tls_inspector_integration_test.cc @@ -133,7 +133,7 @@ class TlsInspectorIntegrationTest : public testing::TestWithParam context_manager_; - Network::TransportSocketFactoryPtr context_; + Network::UpstreamTransportSocketFactoryPtr context_; ConnectionStatusCallbacks connect_callbacks_; testing::NiceMock secret_manager_; Network::ClientConnectionPtr client_; diff --git a/test/extensions/transport_sockets/alts/alts_integration_test.cc b/test/extensions/transport_sockets/alts/alts_integration_test.cc index 69305cdf0c66b..e85f80f8e53ae 100644 --- a/test/extensions/transport_sockets/alts/alts_integration_test.cc +++ b/test/extensions/transport_sockets/alts/alts_integration_test.cc @@ -240,7 +240,7 @@ class AltsIntegrationTestBase : public Event::TestUsingSimulatedTime, std::unique_ptr fake_handshaker_server_; ConditionalInitializer fake_handshaker_server_ci_; int fake_handshaker_server_port_{}; - Network::TransportSocketFactoryPtr client_alts_; + Network::UpstreamTransportSocketFactoryPtr client_alts_; TsiSocket* client_tsi_socket_{nullptr}; bool capturing_handshaker_; CapturingHandshakerService* capturing_handshaker_service_; diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc index 8dbf6f4c270bb..b14bf73e57a90 100644 --- a/test/extensions/transport_sockets/alts/tsi_socket_test.cc +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -906,7 +906,7 @@ class TsiSocketFactoryTest : public testing::Test { socket_factory_ = std::make_unique(handshaker_factory, nullptr); } - Network::TransportSocketFactoryPtr socket_factory_; + Network::UpstreamTransportSocketFactoryPtr socket_factory_; }; TEST_F(TsiSocketFactoryTest, CreateTransportSocket) { diff --git a/test/extensions/transport_sockets/common/passthrough_test.cc b/test/extensions/transport_sockets/common/passthrough_test.cc index 7b055f6f10775..bbd75d3aeec6f 100644 --- a/test/extensions/transport_sockets/common/passthrough_test.cc +++ b/test/extensions/transport_sockets/common/passthrough_test.cc @@ -98,7 +98,7 @@ TEST_F(PassthroughTest, ConfigureInitialCongestionWindowDefersToInnerSocket) { TEST(PassthroughFactoryTest, TestDelegation) { auto inner_factory_ptr = std::make_unique>(); Network::MockTransportSocketFactory* inner_factory = inner_factory_ptr.get(); - Network::TransportSocketFactoryPtr factory{std::move(inner_factory_ptr)}; + Network::UpstreamTransportSocketFactoryPtr factory{std::move(inner_factory_ptr)}; { EXPECT_CALL(*inner_factory, implementsSecureTransport()); @@ -116,6 +116,26 @@ TEST(PassthroughFactoryTest, TestDelegation) { } } +class DownstreamTestFactory : public DownstreamPassthroughFactory { +public: + DownstreamTestFactory(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory) + : DownstreamPassthroughFactory(std::move(transport_socket_factory)) {} + + Network::TransportSocketPtr createDownstreamTransportSocket() const override { return nullptr; } +}; + +TEST(PassthroughFactoryTest, TestDownstreamDelegation) { + auto inner_factory_ptr = + std::make_unique>(); + Network::MockDownstreamTransportSocketFactory* inner_factory = inner_factory_ptr.get(); + auto factory = std::make_unique(std::move(inner_factory_ptr)); + + { + EXPECT_CALL(*inner_factory, implementsSecureTransport()); + factory->implementsSecureTransport(); + } +} + } // namespace } // namespace TransportSockets } // namespace Extensions diff --git a/test/extensions/transport_sockets/starttls/starttls_integration_test.cc b/test/extensions/transport_sockets/starttls/starttls_integration_test.cc index 5d0225e8d3fde..da59a7d71f245 100644 --- a/test/extensions/transport_sockets/starttls/starttls_integration_test.cc +++ b/test/extensions/transport_sockets/starttls/starttls_integration_test.cc @@ -160,8 +160,8 @@ class StartTlsIntegrationTest : public testing::TestWithParam tls_context_manager_; - Network::TransportSocketFactoryPtr tls_context_; - Network::TransportSocketFactoryPtr cleartext_context_; + Network::UpstreamTransportSocketFactoryPtr tls_context_; + Network::UpstreamTransportSocketFactoryPtr cleartext_context_; MockWatermarkBuffer* client_write_buffer_{nullptr}; ConnectionStatusCallbacks connect_callbacks_; @@ -203,7 +203,7 @@ void StartTlsIntegrationTest::initialize() { auto factory = std::make_unique(); - cleartext_context_ = Network::TransportSocketFactoryPtr{ + cleartext_context_ = Network::UpstreamTransportSocketFactoryPtr{ factory->createTransportSocketFactory(*config, factory_context_)}; // Setup factories and contexts for tls transport socket. diff --git a/test/extensions/transport_sockets/starttls/starttls_socket_test.cc b/test/extensions/transport_sockets/starttls/starttls_socket_test.cc index 09aa5a1548b40..29267aa1163f8 100644 --- a/test/extensions/transport_sockets/starttls/starttls_socket_test.cc +++ b/test/extensions/transport_sockets/starttls/starttls_socket_test.cc @@ -177,8 +177,8 @@ TEST(StartTls, BasicFactoryTest) { NiceMock* ssl_factory = new NiceMock; std::unique_ptr factory = std::make_unique( - Network::TransportSocketFactoryPtr(raw_buffer_factory), - Network::TransportSocketFactoryPtr(ssl_factory)); + Network::UpstreamTransportSocketFactoryPtr(raw_buffer_factory), + Network::UpstreamTransportSocketFactoryPtr(ssl_factory)); ASSERT_FALSE(factory->implementsSecureTransport()); std::vector key; factory->hashKey(key, nullptr); diff --git a/test/extensions/transport_sockets/starttls/upstream_starttls_integration_test.cc b/test/extensions/transport_sockets/starttls/upstream_starttls_integration_test.cc index 5126e199d2d47..c5a2aeba79dbc 100644 --- a/test/extensions/transport_sockets/starttls/upstream_starttls_integration_test.cc +++ b/test/extensions/transport_sockets/starttls/upstream_starttls_integration_test.cc @@ -173,7 +173,7 @@ class StartTlsIntegrationTest : public testing::TestWithParam tls_context_manager_; - Network::TransportSocketFactoryPtr tls_context_; + Network::DownstreamTransportSocketFactoryPtr tls_context_; // Technically unused. StreamInfo::StreamInfoImpl stream_info_; @@ -236,7 +236,7 @@ void StartTlsIntegrationTest::initialize() { auto cfg = std::make_unique( downstream_tls_context, mock_factory_ctx); static auto* client_stats_store = new Stats::TestIsolatedStoreImpl(); - tls_context_ = Network::TransportSocketFactoryPtr{ + tls_context_ = Network::DownstreamTransportSocketFactoryPtr{ new Extensions::TransportSockets::Tls::ServerSslSocketFactory( std::move(cfg), *tls_context_manager_, *client_stats_store, {})}; @@ -270,10 +270,7 @@ TEST_P(StartTlsIntegrationTest, SwitchToTlsFromClient) { ASSERT_TRUE(fake_upstream_connection->waitForData(5)); // Create TLS transport socket and install it in fake_upstream. - Network::TransportSocketPtr ts = - tls_context_->createTransportSocket(std::make_shared( - absl::string_view(""), std::vector(), - std::vector{"envoyalpn"})); + Network::TransportSocketPtr ts = tls_context_->createDownstreamTransportSocket(); // Synchronization object used to suspend execution // until dispatcher completes transport socket conversion. diff --git a/test/extensions/transport_sockets/tcp_stats/tcp_stats_test.cc b/test/extensions/transport_sockets/tcp_stats/tcp_stats_test.cc index 4ba61492ce22a..e5453ec814070 100644 --- a/test/extensions/transport_sockets/tcp_stats/tcp_stats_test.cc +++ b/test/extensions/transport_sockets/tcp_stats/tcp_stats_test.cc @@ -260,13 +260,13 @@ class TcpStatsSocketFactoryTest : public testing::Test { envoy::extensions::transport_sockets::tcp_stats::v3::Config proto_config; auto inner_factory = std::make_unique>(); inner_factory_ = inner_factory.get(); - factory_ = - std::make_unique(context_, proto_config, std::move(inner_factory)); + factory_ = std::make_unique(context_, proto_config, + std::move(inner_factory)); } NiceMock context_; NiceMock* inner_factory_; - std::unique_ptr factory_; + std::unique_ptr factory_; }; // Test createTransportSocket returns nullptr if inner call returns nullptr diff --git a/test/extensions/transport_sockets/tls/ssl_socket_test.cc b/test/extensions/transport_sockets/tls/ssl_socket_test.cc index 33ae1fc995afd..0e45d1bc5b3a8 100644 --- a/test/extensions/transport_sockets/tls/ssl_socket_test.cc +++ b/test/extensions/transport_sockets/tls/ssl_socket_test.cc @@ -378,7 +378,7 @@ void testUtil(const TestUtilOptions& options) { NiceMock stream_info; EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { - auto ssl_socket = server_ssl_socket_factory.createTransportSocket(nullptr); + auto ssl_socket = server_ssl_socket_factory.createDownstreamTransportSocket(); // configureInitialCongestionWindow is an unimplemented empty function, this is just to // increase code coverage. ssl_socket->configureInitialCongestionWindow(100, std::chrono::microseconds(123)); @@ -736,7 +736,7 @@ void testUtilV2(const TestUtilOptionsV2& options) { : options.clientCtxProto().sni(); socket->setRequestedServerName(sni); Network::TransportSocketPtr transport_socket = - server_ssl_socket_factory.createTransportSocket(nullptr); + server_ssl_socket_factory.createDownstreamTransportSocket(); EXPECT_FALSE(transport_socket->startSecureTransport()); server_connection = dispatcher->createServerConnection( std::move(socket), std::move(transport_socket), stream_info); @@ -2592,7 +2592,7 @@ TEST_P(SslSocketTest, FlushCloseDuringHandshake) { EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory.createDownstreamTransportSocket(), stream_info_); server_connection->addConnectionCallbacks(server_connection_callbacks); Buffer::OwnedImpl data("hello"); @@ -2662,7 +2662,7 @@ TEST_P(SslSocketTest, HalfClose) { EXPECT_CALL(listener_callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory.createDownstreamTransportSocket(), stream_info_); server_connection->enableHalfClose(true); server_connection->addReadFilter(server_read_filter); @@ -2744,7 +2744,7 @@ TEST_P(SslSocketTest, ShutdownWithCloseNotify) { EXPECT_CALL(listener_callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory.createDownstreamTransportSocket(), stream_info_); server_connection->enableHalfClose(true); server_connection->addReadFilter(server_read_filter); @@ -2832,7 +2832,7 @@ TEST_P(SslSocketTest, ShutdownWithoutCloseNotify) { EXPECT_CALL(listener_callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory.createDownstreamTransportSocket(), stream_info_); server_connection->enableHalfClose(true); server_connection->addReadFilter(server_read_filter); @@ -2948,7 +2948,7 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory.createDownstreamTransportSocket(), stream_info_); server_connection->addConnectionCallbacks(server_connection_callbacks); })); @@ -3036,13 +3036,13 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, StreamInfo::StreamInfoImpl stream_info(time_system, nullptr); EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { - Network::TransportSocketFactory& tsf = + Network::DownstreamTransportSocketFactory& tsf = socket->connectionInfoProvider().localAddress() == socket1->connectionInfoProvider().localAddress() ? server_ssl_socket_factory1 : server_ssl_socket_factory2; server_connection = dispatcher->createServerConnection( - std::move(socket), tsf.createTransportSocket(nullptr), stream_info); + std::move(socket), tsf.createDownstreamTransportSocket(), stream_info); })); EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) @@ -3081,13 +3081,13 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, StreamInfo::StreamInfoImpl stream_info2(time_system, nullptr); EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { - Network::TransportSocketFactory& tsf = + Network::DownstreamTransportSocketFactory& tsf = socket->connectionInfoProvider().localAddress() == socket1->connectionInfoProvider().localAddress() ? server_ssl_socket_factory1 : server_ssl_socket_factory2; server_connection = dispatcher->createServerConnection( - std::move(socket), tsf.createTransportSocket(nullptr), stream_info2); + std::move(socket), tsf.createDownstreamTransportSocket(), stream_info2); server_connection->addConnectionCallbacks(server_connection_callbacks); })); @@ -3177,7 +3177,7 @@ void testSupportForStatelessSessionResumption(const std::string& server_ctx_yaml EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection = dispatcher->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory.createDownstreamTransportSocket(), stream_info); const SslHandshakerImpl* ssl_socket = @@ -3781,13 +3781,13 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& accepted_socket) -> void { - Network::TransportSocketFactory& tsf = + Network::DownstreamTransportSocketFactory& tsf = accepted_socket->connectionInfoProvider().localAddress() == socket->connectionInfoProvider().localAddress() ? server_ssl_socket_factory : server2_ssl_socket_factory; server_connection = dispatcher_->createServerConnection( - std::move(accepted_socket), tsf.createTransportSocket(nullptr), stream_info_); + std::move(accepted_socket), tsf.createDownstreamTransportSocket(), stream_info_); server_connection->addConnectionCallbacks(server_connection_callbacks); })); @@ -3823,13 +3823,13 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& accepted_socket) -> void { - Network::TransportSocketFactory& tsf = + Network::DownstreamTransportSocketFactory& tsf = accepted_socket->connectionInfoProvider().localAddress() == socket->connectionInfoProvider().localAddress() ? server_ssl_socket_factory : server2_ssl_socket_factory; server_connection = dispatcher_->createServerConnection( - std::move(accepted_socket), tsf.createTransportSocket(nullptr), stream_info_); + std::move(accepted_socket), tsf.createDownstreamTransportSocket(), stream_info_); server_connection->addConnectionCallbacks(server_connection_callbacks); })); EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::RemoteClose)); @@ -3914,7 +3914,7 @@ void SslSocketTest::testClientSessionResumption(const std::string& server_ctx_ya EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection = dispatcher->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory.createDownstreamTransportSocket(), stream_info_); server_connection->addConnectionCallbacks(server_connection_callbacks); })); @@ -3960,7 +3960,7 @@ void SslSocketTest::testClientSessionResumption(const std::string& server_ctx_ya EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection = dispatcher->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory.createDownstreamTransportSocket(), stream_info_); server_connection->addConnectionCallbacks(server_connection_callbacks); })); @@ -4144,7 +4144,7 @@ TEST_P(SslSocketTest, SslError) { EXPECT_CALL(callbacks, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory.createDownstreamTransportSocket(), stream_info_); server_connection->addConnectionCallbacks(server_connection_callbacks); })); @@ -5129,7 +5129,7 @@ TEST_P(SslSocketTest, DownstreamNotReadySslSocket) { ContextManagerImpl manager(time_system_); ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, stats_store, std::vector{}); - auto transport_socket = server_ssl_socket_factory.createTransportSocket(nullptr); + auto transport_socket = server_ssl_socket_factory.createDownstreamTransportSocket(); EXPECT_FALSE(transport_socket->startSecureTransport()); // Noop transport_socket->configureInitialCongestionWindow(200, std::chrono::microseconds(223)); // Noop EXPECT_EQ(EMPTY_STRING, transport_socket->protocol()); @@ -5249,7 +5249,7 @@ class SslReadBufferLimitTest : public SslSocketTest { EXPECT_CALL(listener_callbacks_, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection_ = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory_->createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory_->createDownstreamTransportSocket(), stream_info_); server_connection_->setBufferLimits(read_buffer_limit); server_connection_->addConnectionCallbacks(server_callbacks_); @@ -5325,7 +5325,7 @@ class SslReadBufferLimitTest : public SslSocketTest { EXPECT_CALL(listener_callbacks_, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection_ = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory_->createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory_->createDownstreamTransportSocket(), stream_info_); server_connection_->setBufferLimits(read_buffer_limit); server_connection_->addConnectionCallbacks(server_callbacks_); @@ -5393,11 +5393,11 @@ class SslReadBufferLimitTest : public SslSocketTest { NiceMock runtime_; envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext downstream_tls_context_; std::unique_ptr manager_; - Network::TransportSocketFactoryPtr server_ssl_socket_factory_; + Network::DownstreamTransportSocketFactoryPtr server_ssl_socket_factory_; Network::ListenerPtr listener_; envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext upstream_tls_context_; Envoy::Ssl::ClientContextSharedPtr client_ctx_; - Network::TransportSocketFactoryPtr client_ssl_socket_factory_; + Network::UpstreamTransportSocketFactoryPtr client_ssl_socket_factory_; Network::ClientConnectionPtr client_connection_; Network::TransportSocket* client_transport_socket_{}; Network::ConnectionPtr server_connection_; @@ -5445,7 +5445,7 @@ TEST_P(SslReadBufferLimitTest, TestBind) { EXPECT_CALL(listener_callbacks_, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection_ = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory_->createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory_->createDownstreamTransportSocket(), stream_info_); server_connection_->addConnectionCallbacks(server_callbacks_); server_connection_->addReadFilter(read_filter_); @@ -5476,7 +5476,7 @@ TEST_P(SslReadBufferLimitTest, SmallReadsIntoSameSlice) { EXPECT_CALL(listener_callbacks_, onAccept_(_)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { server_connection_ = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory_->createTransportSocket(nullptr), + std::move(socket), server_ssl_socket_factory_->createDownstreamTransportSocket(), stream_info_); server_connection_->setBufferLimits(read_buffer_limit); server_connection_->addConnectionCallbacks(server_callbacks_); diff --git a/test/integration/alpn_selection_integration_test.cc b/test/integration/alpn_selection_integration_test.cc index da784ec70e3cf..5cb436295cb10 100644 --- a/test/integration/alpn_selection_integration_test.cc +++ b/test/integration/alpn_selection_integration_test.cc @@ -50,7 +50,7 @@ name: tls HttpIntegrationTest::initialize(); } - Network::TransportSocketFactoryPtr createUpstreamSslContext() { + Network::DownstreamTransportSocketFactoryPtr createUpstreamSslContext() { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; const std::string yaml = absl::StrFormat( R"EOF( diff --git a/test/integration/autonomous_upstream.h b/test/integration/autonomous_upstream.h index aa0e15fb1baa1..fe87b3b335f5f 100644 --- a/test/integration/autonomous_upstream.h +++ b/test/integration/autonomous_upstream.h @@ -60,7 +60,7 @@ using AutonomousHttpConnectionPtr = std::unique_ptr; // An upstream which creates AutonomousHttpConnection for new incoming connections. class AutonomousUpstream : public FakeUpstream { public: - AutonomousUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, + AutonomousUpstream(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory, const Network::Address::InstanceConstSharedPtr& address, const FakeUpstreamConfig& config, bool allow_incomplete_streams) : FakeUpstream(std::move(transport_socket_factory), address, config), @@ -69,9 +69,9 @@ class AutonomousUpstream : public FakeUpstream { response_headers_(std::make_unique( Http::TestResponseHeaderMapImpl({{":status", "200"}}))) {} - AutonomousUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, uint32_t port, - Network::Address::IpVersion version, const FakeUpstreamConfig& config, - bool allow_incomplete_streams) + AutonomousUpstream(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory, + uint32_t port, Network::Address::IpVersion version, + const FakeUpstreamConfig& config, bool allow_incomplete_streams) : FakeUpstream(std::move(transport_socket_factory), port, version, config), allow_incomplete_streams_(allow_incomplete_streams), response_trailers_(std::make_unique()), diff --git a/test/integration/base_integration_test.cc b/test/integration/base_integration_test.cc index 097d1147f1178..2534b6dfedf91 100644 --- a/test/integration/base_integration_test.cc +++ b/test/integration/base_integration_test.cc @@ -109,7 +109,7 @@ void BaseIntegrationTest::initialize() { } } -Network::TransportSocketFactoryPtr +Network::DownstreamTransportSocketFactoryPtr BaseIntegrationTest::createUpstreamTlsContext(const FakeUpstreamConfig& upstream_config) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; const std::string yaml = absl::StrFormat( @@ -156,9 +156,9 @@ void BaseIntegrationTest::createUpstreams() { } void BaseIntegrationTest::createUpstream(Network::Address::InstanceConstSharedPtr endpoint, FakeUpstreamConfig& config) { - Network::TransportSocketFactoryPtr factory = upstream_tls_ - ? createUpstreamTlsContext(config) - : Network::Test::createRawBufferSocketFactory(); + Network::DownstreamTransportSocketFactoryPtr factory = + upstream_tls_ ? createUpstreamTlsContext(config) + : Network::Test::createRawBufferDownstreamSocketFactory(); if (autonomous_upstream_) { fake_upstreams_.emplace_back(new AutonomousUpstream(std::move(factory), endpoint, config, autonomous_allow_incomplete_streams_)); diff --git a/test/integration/base_integration_test.h b/test/integration/base_integration_test.h index 43c7a48a7706a..2a75c778537ca 100644 --- a/test/integration/base_integration_test.h +++ b/test/integration/base_integration_test.h @@ -357,8 +357,9 @@ class BaseIntegrationTest : protected Logger::Loggable { return *fake_upstreams_.back(); } - FakeUpstream& addFakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, - Http::CodecType type) { + FakeUpstream& + addFakeUpstream(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory, + Http::CodecType type) { auto config = configWithType(type); fake_upstreams_.emplace_back( std::make_unique(std::move(transport_socket_factory), 0, version_, config)); @@ -446,7 +447,7 @@ class BaseIntegrationTest : protected Logger::Loggable { bool use_lds_{true}; // Use the integration framework's LDS set up. bool upstream_tls_{false}; - Network::TransportSocketFactoryPtr + Network::DownstreamTransportSocketFactoryPtr createUpstreamTlsContext(const FakeUpstreamConfig& upstream_config); testing::NiceMock factory_context_; Extensions::TransportSockets::Tls::ContextManagerImpl context_manager_{timeSystem()}; diff --git a/test/integration/fake_upstream.cc b/test/integration/fake_upstream.cc index 5312bdd4c7141..b432850128662 100644 --- a/test/integration/fake_upstream.cc +++ b/test/integration/fake_upstream.cc @@ -506,7 +506,7 @@ AssertionResult FakeHttpConnection::waitForNewStream(Event::Dispatcher& client_d } FakeUpstream::FakeUpstream(const std::string& uds_path, const FakeUpstreamConfig& config) - : FakeUpstream(Network::Test::createRawBufferSocketFactory(), + : FakeUpstream(Network::Test::createRawBufferDownstreamSocketFactory(), Network::SocketPtr{new Network::UdsListenSocket( std::make_shared(uds_path))}, config) {} @@ -541,22 +541,22 @@ makeListenSocket(const FakeUpstreamConfig& config, FakeUpstream::FakeUpstream(uint32_t port, Network::Address::IpVersion version, const FakeUpstreamConfig& config) - : FakeUpstream(Network::Test::createRawBufferSocketFactory(), + : FakeUpstream(Network::Test::createRawBufferDownstreamSocketFactory(), makeListenSocket(config, makeAddress(port, version)), config) {} -FakeUpstream::FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, +FakeUpstream::FakeUpstream(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory, const Network::Address::InstanceConstSharedPtr& address, const FakeUpstreamConfig& config) : FakeUpstream(std::move(transport_socket_factory), makeListenSocket(config, address), config) { } -FakeUpstream::FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, +FakeUpstream::FakeUpstream(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory, uint32_t port, Network::Address::IpVersion version, const FakeUpstreamConfig& config) : FakeUpstream(std::move(transport_socket_factory), makeListenSocket(config, makeAddress(port, version)), config) {} -FakeUpstream::FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, +FakeUpstream::FakeUpstream(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory, Network::SocketPtr&& listen_socket, const FakeUpstreamConfig& config) : http_type_(config.upstream_protocol_), http2_options_(config.http2_options_), http3_options_(config.http3_options_), quic_options_(config.quic_options_), diff --git a/test/integration/fake_upstream.h b/test/integration/fake_upstream.h index 1f49cfacea48d..63f7fd81a4670 100644 --- a/test/integration/fake_upstream.h +++ b/test/integration/fake_upstream.h @@ -608,7 +608,7 @@ class FakeUpstream : Logger::Loggable, FakeUpstream(const std::string& uds_path, const FakeUpstreamConfig& config); // Creates a fake upstream bound to the specified |address|. - FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, + FakeUpstream(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory, const Network::Address::InstanceConstSharedPtr& address, const FakeUpstreamConfig& config); @@ -616,8 +616,9 @@ class FakeUpstream : Logger::Loggable, FakeUpstream(uint32_t port, Network::Address::IpVersion version, const FakeUpstreamConfig& config); - FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, uint32_t port, - Network::Address::IpVersion version, const FakeUpstreamConfig& config); + FakeUpstream(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory, + uint32_t port, Network::Address::IpVersion version, + const FakeUpstreamConfig& config); ~FakeUpstream() override; Http::CodecType httpType() { return http_type_; } @@ -712,7 +713,7 @@ class FakeUpstream : Logger::Loggable, const Http::CodecType http_type_; private: - FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, + FakeUpstream(Network::DownstreamTransportSocketFactoryPtr&& transport_socket_factory, Network::SocketPtr&& connection, const FakeUpstreamConfig& config); class FakeListenSocketFactory : public Network::ListenSocketFactory { diff --git a/test/integration/http_integration.h b/test/integration/http_integration.h index 5f688940db5bd..cd0d42f792575 100644 --- a/test/integration/http_integration.h +++ b/test/integration/http_integration.h @@ -300,7 +300,7 @@ class HttpIntegrationTest : public BaseIntegrationTest { // Prefix listener stat with IP:port, including IP version dependent loopback address. std::string listenerStatPrefix(const std::string& stat_name); - Network::TransportSocketFactoryPtr quic_transport_socket_factory_; + Network::UpstreamTransportSocketFactoryPtr quic_transport_socket_factory_; // Must outlive |codec_client_| because it may not close connection till the end of its life // scope. std::unique_ptr quic_connection_persistent_info_; diff --git a/test/integration/multiplexed_upstream_integration_test.cc b/test/integration/multiplexed_upstream_integration_test.cc index eb62dad34c525..75a3019aa3bbb 100644 --- a/test/integration/multiplexed_upstream_integration_test.cc +++ b/test/integration/multiplexed_upstream_integration_test.cc @@ -852,14 +852,13 @@ class QuicFailHandshakeCryptoServerStreamFactory } std::string name() const override { return "envoy.quic.crypto_stream.server.fail_handshake"; } - std::unique_ptr - createEnvoyQuicCryptoServerStream(const quic::QuicCryptoServerConfig* crypto_config, - quic::QuicCompressedCertsCache* /*compressed_certs_cache*/, - quic::QuicSession* session, - quic::QuicCryptoServerStreamBase::Helper* /*helper*/, - Envoy::OptRef - /*transport_socket_factory*/, - Envoy::Event::Dispatcher& /*dispatcher*/) override { + std::unique_ptr createEnvoyQuicCryptoServerStream( + const quic::QuicCryptoServerConfig* crypto_config, + quic::QuicCompressedCertsCache* /*compressed_certs_cache*/, quic::QuicSession* session, + quic::QuicCryptoServerStreamBase::Helper* /*helper*/, + Envoy::OptRef + /*transport_socket_factory*/, + Envoy::Event::Dispatcher& /*dispatcher*/) override { ASSERT(session->connection()->version().handshake_protocol == quic::PROTOCOL_TLS1_3); return std::make_unique(session, crypto_config, fail_handshake_); } diff --git a/test/integration/sds_dynamic_integration_test.cc b/test/integration/sds_dynamic_integration_test.cc index b98674d637211..07c0002c5b224 100644 --- a/test/integration/sds_dynamic_integration_test.cc +++ b/test/integration/sds_dynamic_integration_test.cc @@ -315,7 +315,7 @@ version_info: "0" } protected: - Network::TransportSocketFactoryPtr client_ssl_ctx_; + Network::UpstreamTransportSocketFactoryPtr client_ssl_ctx_; bool dual_cert_{false}; }; @@ -591,7 +591,7 @@ class SdsDynamicDownstreamCertValidationContextTest : public SdsDynamicDownstrea create_xds_upstream_ = true; } - Network::TransportSocketFactoryPtr createUpstreamSslContext() { + Network::DownstreamTransportSocketFactoryPtr createUpstreamSslContext() { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; auto* common_tls_context = tls_context.mutable_common_tls_context(); auto* tls_certificate = common_tls_context->add_tls_certificates(); diff --git a/test/integration/sds_static_integration_test.cc b/test/integration/sds_static_integration_test.cc index 67cd1499d2485..24a4efe1ae36a 100644 --- a/test/integration/sds_static_integration_test.cc +++ b/test/integration/sds_static_integration_test.cc @@ -91,7 +91,7 @@ class SdsStaticDownstreamIntegrationTest private: Extensions::TransportSockets::Tls::ContextManagerImpl context_manager_{timeSystem()}; - Network::TransportSocketFactoryPtr client_ssl_ctx_; + Network::UpstreamTransportSocketFactoryPtr client_ssl_ctx_; }; INSTANTIATE_TEST_SUITE_P(IpVersions, SdsStaticDownstreamIntegrationTest, diff --git a/test/integration/ssl_utility.cc b/test/integration/ssl_utility.cc index 1c01509071bac..cfc7f12f9eaeb 100644 --- a/test/integration/ssl_utility.cc +++ b/test/integration/ssl_utility.cc @@ -103,7 +103,7 @@ void initializeUpstreamTlsContextConfig( common_context->mutable_tls_params()->set_tls_maximum_protocol_version(options.tls_version_); } -Network::TransportSocketFactoryPtr +Network::UpstreamTransportSocketFactoryPtr createClientSslTransportSocketFactory(const ClientSslTransportOptions& options, ContextManager& context_manager, Api::Api& api) { envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context; @@ -114,13 +114,13 @@ createClientSslTransportSocketFactory(const ClientSslTransportOptions& options, auto cfg = std::make_unique( tls_context, options.sigalgs_, mock_factory_ctx); static auto* client_stats_store = new Stats::TestIsolatedStoreImpl(); - return Network::TransportSocketFactoryPtr{ + return Network::UpstreamTransportSocketFactoryPtr{ new Extensions::TransportSockets::Tls::ClientSslSocketFactory(std::move(cfg), context_manager, *client_stats_store)}; } -Network::TransportSocketFactoryPtr createUpstreamSslContext(ContextManager& context_manager, - Api::Api& api, bool use_http3) { +Network::DownstreamTransportSocketFactoryPtr +createUpstreamSslContext(ContextManager& context_manager, Api::Api& api, bool use_http3) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; ConfigHelper::initializeTls({}, *tls_context.mutable_common_tls_context()); @@ -144,7 +144,7 @@ Network::TransportSocketFactoryPtr createUpstreamSslContext(ContextManager& cont return config_factory.createTransportSocketFactory(quic_config, mock_factory_ctx, server_names); } -Network::TransportSocketFactoryPtr createFakeUpstreamSslContext( +Network::DownstreamTransportSocketFactoryPtr createFakeUpstreamSslContext( const std::string& upstream_cert_name, ContextManager& context_manager, Server::Configuration::TransportSocketFactoryContext& factory_context) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; diff --git a/test/integration/ssl_utility.h b/test/integration/ssl_utility.h index 7e988094b775c..a4ac078f413d6 100644 --- a/test/integration/ssl_utility.h +++ b/test/integration/ssl_utility.h @@ -73,14 +73,14 @@ void initializeUpstreamTlsContextConfig( const ClientSslTransportOptions& options, envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext& tls_context); -Network::TransportSocketFactoryPtr +Network::UpstreamTransportSocketFactoryPtr createClientSslTransportSocketFactory(const ClientSslTransportOptions& options, ContextManager& context_manager, Api::Api& api); -Network::TransportSocketFactoryPtr createUpstreamSslContext(ContextManager& context_manager, - Api::Api& api, bool use_http3 = false); +Network::DownstreamTransportSocketFactoryPtr +createUpstreamSslContext(ContextManager& context_manager, Api::Api& api, bool use_http3 = false); -Network::TransportSocketFactoryPtr +Network::DownstreamTransportSocketFactoryPtr createFakeUpstreamSslContext(const std::string& upstream_cert_name, ContextManager& context_manager, Server::Configuration::TransportSocketFactoryContext& factory_context); diff --git a/test/integration/tcp_proxy_integration_test.cc b/test/integration/tcp_proxy_integration_test.cc index 66d629f2ede3a..3835c7629e0ba 100644 --- a/test/integration/tcp_proxy_integration_test.cc +++ b/test/integration/tcp_proxy_integration_test.cc @@ -1436,9 +1436,9 @@ class MysqlIntegrationTest : public TcpProxyIntegrationTest { public: void createUpstreams() override { for (uint32_t i = 0; i < fake_upstreams_count_; ++i) { - Network::TransportSocketFactoryPtr factory = + Network::DownstreamTransportSocketFactoryPtr factory = upstream_tls_ ? createUpstreamTlsContext(upstreamConfig()) - : Network::Test::createRawBufferSocketFactory(); + : Network::Test::createRawBufferDownstreamSocketFactory(); auto endpoint = upstream_address_fn_(i); fake_upstreams_.emplace_back( new FakeMysqlUpstream(std::move(factory), endpoint, upstreamConfig())); diff --git a/test/integration/tcp_proxy_integration_test.h b/test/integration/tcp_proxy_integration_test.h index 7aeb7d137c65c..dcdba183c5a4d 100644 --- a/test/integration/tcp_proxy_integration_test.h +++ b/test/integration/tcp_proxy_integration_test.h @@ -35,7 +35,7 @@ class TcpProxySslIntegrationTest : public TcpProxyIntegrationTest { const std::string& data_to_send_downstream); std::unique_ptr context_manager_; - Network::TransportSocketFactoryPtr context_; + Network::UpstreamTransportSocketFactoryPtr context_; ConnectionStatusCallbacks connect_callbacks_; MockWatermarkBuffer* client_write_buffer_; std::shared_ptr payload_reader_; diff --git a/test/integration/transport_socket_match_integration_test.cc b/test/integration/transport_socket_match_integration_test.cc index 6b97546e038b1..2fcba4732bda6 100644 --- a/test/integration/transport_socket_match_integration_test.cc +++ b/test/integration/transport_socket_match_integration_test.cc @@ -122,7 +122,7 @@ name: "tls_socket" endpoint->ip()->version(), upstreamConfig(), false)); } else { fake_upstreams_.emplace_back(new AutonomousUpstream( - Network::Test::createRawBufferSocketFactory(), endpoint->ip()->port(), + Network::Test::createRawBufferDownstreamSocketFactory(), endpoint->ip()->port(), endpoint->ip()->version(), upstreamConfig(), false)); } } diff --git a/test/integration/upstream_filter_state_integration_test.cc b/test/integration/upstream_filter_state_integration_test.cc index 9433dc1b07c10..3b87ee8b40148 100644 --- a/test/integration/upstream_filter_state_integration_test.cc +++ b/test/integration/upstream_filter_state_integration_test.cc @@ -58,7 +58,7 @@ class Socket : public Extensions::TransportSockets::PassthroughSocket { class SocketFactory : public Extensions::TransportSockets::PassthroughFactory { public: - SocketFactory(Network::TransportSocketFactoryPtr&& inner_factory) + SocketFactory(Network::UpstreamTransportSocketFactoryPtr&& inner_factory) : PassthroughFactory(std::move(inner_factory)) {} Network::TransportSocketPtr @@ -79,7 +79,7 @@ class SocketConfigFactory : public Server::Configuration::UpstreamTransportSocke return std::make_unique(); } - Network::TransportSocketFactoryPtr createTransportSocketFactory( + Network::UpstreamTransportSocketFactoryPtr createTransportSocketFactory( const Protobuf::Message& config, Server::Configuration::TransportSocketFactoryContext& context) override { const auto& outer_config = diff --git a/test/integration/utility.cc b/test/integration/utility.cc index 37ca2275c8d8b..b7cf9b5d60f59 100644 --- a/test/integration/utility.cc +++ b/test/integration/utility.cc @@ -125,7 +125,7 @@ class TestConnectionCallbacks : public Network::ConnectionCallbacks { bool connected_{false}; }; -Network::TransportSocketFactoryPtr +Network::UpstreamTransportSocketFactoryPtr IntegrationUtil::createQuicUpstreamTransportSocketFactory(Api::Api& api, Stats::Store& store, Ssl::ContextManager& context_manager, const std::string& san_to_match) { @@ -211,7 +211,7 @@ IntegrationUtil::makeSingleRequest(const Network::Address::InstanceConstSharedPt #ifdef ENVOY_ENABLE_QUIC Extensions::TransportSockets::Tls::ContextManagerImpl manager(time_system); - Network::TransportSocketFactoryPtr transport_socket_factory = + Network::UpstreamTransportSocketFactoryPtr transport_socket_factory = createQuicUpstreamTransportSocketFactory(api, mock_stats_store, manager, "spiffe://lyft.com/backend-team"); auto& quic_transport_socket_factory = diff --git a/test/integration/utility.h b/test/integration/utility.h index bc907e8a1d064..262827297a98a 100644 --- a/test/integration/utility.h +++ b/test/integration/utility.h @@ -204,7 +204,7 @@ class IntegrationUtil { * Create transport socket factory for Quic upstream transport socket. * @return TransportSocketFactoryPtr the client transport socket factory. */ - static Network::TransportSocketFactoryPtr + static Network::UpstreamTransportSocketFactoryPtr createQuicUpstreamTransportSocketFactory(Api::Api& api, Stats::Store& store, Ssl::ContextManager& context_manager, const std::string& san_to_match); diff --git a/test/integration/xds_integration_test.cc b/test/integration/xds_integration_test.cc index fdd4ff666d35f..490117984b807 100644 --- a/test/integration/xds_integration_test.cc +++ b/test/integration/xds_integration_test.cc @@ -207,7 +207,7 @@ class LdsInplaceUpdateTcpProxyIntegrationTest } std::unique_ptr context_manager_; - Network::TransportSocketFactoryPtr context_; + Network::UpstreamTransportSocketFactoryPtr context_; testing::NiceMock secret_manager_; bool matcher_; }; @@ -495,7 +495,7 @@ class LdsInplaceUpdateHttpIntegrationTest } std::unique_ptr context_manager_; - Network::TransportSocketFactoryPtr context_; + Network::UpstreamTransportSocketFactoryPtr context_; testing::NiceMock secret_manager_; Network::Address::InstanceConstSharedPtr address_; bool use_default_balancer_{false}; diff --git a/test/integration/xfcc_integration_test.cc b/test/integration/xfcc_integration_test.cc index 95038e5893a9b..87eb80723ce77 100644 --- a/test/integration/xfcc_integration_test.cc +++ b/test/integration/xfcc_integration_test.cc @@ -40,7 +40,7 @@ void XfccIntegrationTest::TearDown() { context_manager_.reset(); } -Network::TransportSocketFactoryPtr XfccIntegrationTest::createClientSslContext(bool mtls) { +Network::UpstreamTransportSocketFactoryPtr XfccIntegrationTest::createClientSslContext(bool mtls) { const std::string yaml_tls = R"EOF( common_tls_context: validation_context: @@ -97,12 +97,12 @@ Network::TransportSocketFactoryPtr XfccIntegrationTest::createClientSslContext(b auto cfg = std::make_unique( config, factory_context_); static auto* client_stats_store = new Stats::TestIsolatedStoreImpl(); - return Network::TransportSocketFactoryPtr{ + return Network::UpstreamTransportSocketFactoryPtr{ new Extensions::TransportSockets::Tls::ClientSslSocketFactory( std::move(cfg), *context_manager_, *client_stats_store)}; } -Network::TransportSocketFactoryPtr XfccIntegrationTest::createUpstreamSslContext() { +Network::DownstreamTransportSocketFactoryPtr XfccIntegrationTest::createUpstreamSslContext() { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; auto* common_tls_context = tls_context.mutable_common_tls_context(); auto* tls_cert = common_tls_context->add_tls_certificates(); diff --git a/test/integration/xfcc_integration_test.h b/test/integration/xfcc_integration_test.h index 1564565dbae72..def7807d94b27 100644 --- a/test/integration/xfcc_integration_test.h +++ b/test/integration/xfcc_integration_test.h @@ -48,8 +48,8 @@ class XfccIntegrationTest : public testing::TestWithParam context_manager_; - Network::TransportSocketFactoryPtr client_tls_ssl_ctx_; - Network::TransportSocketFactoryPtr client_mtls_ssl_ctx_; - Network::TransportSocketFactoryPtr upstream_ssl_ctx_; + Network::UpstreamTransportSocketFactoryPtr client_tls_ssl_ctx_; + Network::UpstreamTransportSocketFactoryPtr client_mtls_ssl_ctx_; + Network::UpstreamTransportSocketFactoryPtr upstream_ssl_ctx_; testing::NiceMock factory_context_; }; } // namespace Xfcc diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index 27c5799b00f9a..5270dddf9e0be 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -229,7 +229,7 @@ class MockFilterChain : public DrainableFilterChain { ~MockFilterChain() override; // Network::DrainableFilterChain - MOCK_METHOD(const TransportSocketFactory&, transportSocketFactory, (), (const)); + MOCK_METHOD(const DownstreamTransportSocketFactory&, transportSocketFactory, (), (const)); MOCK_METHOD(std::chrono::milliseconds, transportSocketConnectTimeout, (), (const)); MOCK_METHOD(const std::vector&, networkFilterFactories, (), (const)); MOCK_METHOD(void, startDraining, ()); diff --git a/test/mocks/network/transport_socket.cc b/test/mocks/network/transport_socket.cc index 5fb7a916ad04a..a72b2f008bb85 100644 --- a/test/mocks/network/transport_socket.cc +++ b/test/mocks/network/transport_socket.cc @@ -22,5 +22,8 @@ MockTransportSocket::~MockTransportSocket() = default; MockTransportSocketFactory::MockTransportSocketFactory() = default; MockTransportSocketFactory::~MockTransportSocketFactory() = default; +MockDownstreamTransportSocketFactory::MockDownstreamTransportSocketFactory() = default; +MockDownstreamTransportSocketFactory::~MockDownstreamTransportSocketFactory() = default; + } // namespace Network } // namespace Envoy diff --git a/test/mocks/network/transport_socket.h b/test/mocks/network/transport_socket.h index 87ae7ad4e4e41..df282394ddc14 100644 --- a/test/mocks/network/transport_socket.h +++ b/test/mocks/network/transport_socket.h @@ -33,7 +33,7 @@ class MockTransportSocket : public TransportSocket { TransportSocketCallbacks* callbacks_{}; }; -class MockTransportSocketFactory : public TransportSocketFactory { +class MockTransportSocketFactory : public UpstreamTransportSocketFactory { public: MockTransportSocketFactory(); ~MockTransportSocketFactory() override; @@ -47,5 +47,14 @@ class MockTransportSocketFactory : public TransportSocketFactory { (std::vector & key, TransportSocketOptionsConstSharedPtr options), (const)); }; +class MockDownstreamTransportSocketFactory : public DownstreamTransportSocketFactory { +public: + MockDownstreamTransportSocketFactory(); + ~MockDownstreamTransportSocketFactory() override; + + MOCK_METHOD(bool, implementsSecureTransport, (), (const)); + MOCK_METHOD(TransportSocketPtr, createDownstreamTransportSocket, (), (const)); +}; + } // namespace Network } // namespace Envoy diff --git a/test/mocks/upstream/host.h b/test/mocks/upstream/host.h index 89d3bc29f287d..8a2eade1fa5a2 100644 --- a/test/mocks/upstream/host.h +++ b/test/mocks/upstream/host.h @@ -96,7 +96,7 @@ class MockHostDescription : public HostDescription { MOCK_METHOD(HealthCheckHostMonitor&, healthChecker, (), (const)); MOCK_METHOD(const std::string&, hostnameForHealthChecks, (), (const)); MOCK_METHOD(const std::string&, hostname, (), (const)); - MOCK_METHOD(Network::TransportSocketFactory&, transportSocketFactory, (), (const)); + MOCK_METHOD(Network::UpstreamTransportSocketFactory&, transportSocketFactory, (), (const)); MOCK_METHOD(HostStats&, stats, (), (const)); MOCK_METHOD(LoadMetricStats&, loadMetricStats, (), (const)); MOCK_METHOD(const envoy::config::core::v3::Locality&, locality, (), (const)); @@ -114,7 +114,7 @@ class MockHostDescription : public HostDescription { Network::Address::InstanceConstSharedPtr address_; testing::NiceMock outlier_detector_; testing::NiceMock health_checker_; - Network::TransportSocketFactoryPtr socket_factory_; + Network::UpstreamTransportSocketFactoryPtr socket_factory_; testing::NiceMock cluster_; HostStats stats_; LoadMetricStatsImpl load_metric_stats_; @@ -188,7 +188,7 @@ class MockHost : public Host { MOCK_METHOD(Host::Health, health, (), (const)); MOCK_METHOD(const std::string&, hostnameForHealthChecks, (), (const)); MOCK_METHOD(const std::string&, hostname, (), (const)); - MOCK_METHOD(Network::TransportSocketFactory&, transportSocketFactory, (), (const)); + MOCK_METHOD(Network::UpstreamTransportSocketFactory&, transportSocketFactory, (), (const)); MOCK_METHOD(Outlier::DetectorHostMonitor&, outlierDetector, (), (const)); MOCK_METHOD(void, setHealthChecker_, (HealthCheckHostMonitorPtr & health_checker)); MOCK_METHOD(void, setOutlierDetector_, (Outlier::DetectorHostMonitorPtr & outlier_detector)); @@ -205,7 +205,7 @@ class MockHost : public Host { MOCK_METHOD(MonotonicTime, creationTime, (), (const)); testing::NiceMock cluster_; - Network::TransportSocketFactoryPtr socket_factory_; + Network::UpstreamTransportSocketFactoryPtr socket_factory_; testing::NiceMock outlier_detector_; HostStats stats_; LoadMetricStatsImpl load_metric_stats_; diff --git a/test/mocks/upstream/transport_socket_match.cc b/test/mocks/upstream/transport_socket_match.cc index 0f12148f90d3a..f0c598aadc6f4 100644 --- a/test/mocks/upstream/transport_socket_match.cc +++ b/test/mocks/upstream/transport_socket_match.cc @@ -11,7 +11,8 @@ namespace Upstream { MockTransportSocketMatcher::MockTransportSocketMatcher() : MockTransportSocketMatcher(std::make_unique()) {} -MockTransportSocketMatcher::MockTransportSocketMatcher(Network::TransportSocketFactoryPtr factory) +MockTransportSocketMatcher::MockTransportSocketMatcher( + Network::UpstreamTransportSocketFactoryPtr factory) : socket_factory_(std::move(factory)), stats_({ALL_TRANSPORT_SOCKET_MATCH_STATS(POOL_COUNTER_PREFIX(stats_store_, "test"))}) { ON_CALL(*this, resolve(_)) diff --git a/test/mocks/upstream/transport_socket_match.h b/test/mocks/upstream/transport_socket_match.h index 5dd6b52758e8f..2fcdbbb16c55a 100644 --- a/test/mocks/upstream/transport_socket_match.h +++ b/test/mocks/upstream/transport_socket_match.h @@ -15,13 +15,13 @@ namespace Upstream { class MockTransportSocketMatcher : public TransportSocketMatcher { public: MockTransportSocketMatcher(); - MockTransportSocketMatcher(Network::TransportSocketFactoryPtr default_factory); + MockTransportSocketMatcher(Network::UpstreamTransportSocketFactoryPtr default_factory); ~MockTransportSocketMatcher() override; MOCK_METHOD(TransportSocketMatcher::MatchData, resolve, (const envoy::config::core::v3::Metadata*), (const)); MOCK_METHOD(bool, allMatchesSupportAlpn, (), (const)); - Network::TransportSocketFactoryPtr socket_factory_; + Network::UpstreamTransportSocketFactoryPtr socket_factory_; Stats::TestUtil::TestStore stats_store_; TransportSocketMatchStats stats_; }; diff --git a/test/server/active_tcp_listener_test.cc b/test/server/active_tcp_listener_test.cc index 159b04338efaa..a053999643188 100644 --- a/test/server/active_tcp_listener_test.cc +++ b/test/server/active_tcp_listener_test.cc @@ -555,7 +555,7 @@ TEST_F(ActiveTcpListenerTest, RedirectedRebalancer) { ReturnRef(*active_listener2))); auto filter_factory_callback = std::make_shared>(); - auto transport_socket_factory = Network::Test::createRawBufferSocketFactory(); + auto transport_socket_factory = Network::Test::createRawBufferDownstreamSocketFactory(); filter_chain_ = std::make_shared>(); EXPECT_CALL(conn_handler_, incNumConnections()); diff --git a/test/server/connection_handler_test.cc b/test/server/connection_handler_test.cc index fb13878403f94..1e5a9c55c2666 100644 --- a/test/server/connection_handler_test.cc +++ b/test/server/connection_handler_test.cc @@ -56,8 +56,8 @@ class ConnectionHandlerTest : public testing::Test, protected Logger::Loggable>()), access_log_(std::make_shared()) { ON_CALL(*filter_chain_, transportSocketFactory) - .WillByDefault(ReturnPointee(std::shared_ptr{ - Network::Test::createRawBufferSocketFactory()})); + .WillByDefault(ReturnPointee(std::shared_ptr{ + Network::Test::createRawBufferDownstreamSocketFactory()})); ON_CALL(*filter_chain_, networkFilterFactories) .WillByDefault(ReturnPointee(std::make_shared>())); ON_CALL(*listener_filter_matcher_, matches(_)).WillByDefault(Return(false)); diff --git a/test/server/listener_manager_impl_test.cc b/test/server/listener_manager_impl_test.cc index c88aab06d09d2..c0c481da944f2 100644 --- a/test/server/listener_manager_impl_test.cc +++ b/test/server/listener_manager_impl_test.cc @@ -2639,7 +2639,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDestinationP filter_chain = findFilterChain(8080, "127.0.0.1", "", "tls", {}, "8.8.8.8", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -2711,7 +2711,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDirectSource filter_chain = findFilterChain(1234, "1.2.3.4", "", "tls", {}, "8.8.8.8", 111, "127.0.0.1"); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -2783,7 +2783,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDestinationI filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "8.8.8.8", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -2854,7 +2854,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithServerNamesM findFilterChain(1234, "127.0.0.1", "server1.example.com", "tls", {}, "8.8.8.8", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -2916,7 +2916,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithTransportPro filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "8.8.8.8", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -2983,7 +2983,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithApplicationP 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -3049,7 +3049,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithSourceTypeMa 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -3063,7 +3063,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithSourceTypeMa "/tmp/test.sock", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -3136,7 +3136,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithSourceIpMatc 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -3274,7 +3274,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithSourcePortMa auto filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "127.0.0.1", 100); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -3407,7 +3407,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithSourceTyp filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -3421,7 +3421,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithSourceTyp 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); auto uri = ssl_socket->ssl()->uriSanLocalCertificate(); EXPECT_EQ(uri[0], "spiffe://lyft.com/test-team"); @@ -3430,7 +3430,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithSourceTyp filter_chain = findFilterChain(1234, "8.8.8.8", "", "tls", {}, "4.4.4.4", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 2); @@ -3521,7 +3521,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati auto filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto uri = ssl_socket->ssl()->uriSanLocalCertificate(); @@ -3531,7 +3531,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(8080, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -3541,7 +3541,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(8081, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 2); @@ -3551,7 +3551,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(0, "/tmp/test.sock", "", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); uri = ssl_socket->ssl()->uriSanLocalCertificate(); EXPECT_EQ(uri[0], "spiffe://lyft.com/test-team"); @@ -3650,7 +3650,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati auto filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto uri = ssl_socket->ssl()->uriSanLocalCertificate(); @@ -3660,7 +3660,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); uri = ssl_socket->ssl()->uriSanLocalCertificate(); EXPECT_EQ(uri[0], "spiffe://lyft.com/test-team"); @@ -3669,7 +3669,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(1234, "192.168.0.1", "", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -3679,7 +3679,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(1234, "192.168.1.1", "", "tls", {}, "192.168.1.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 2); @@ -3689,7 +3689,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(0, "/tmp/test.sock", "", "tls", {}, "/tmp/test.sock", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); uri = ssl_socket->ssl()->uriSanLocalCertificate(); EXPECT_EQ(uri[0], "spiffe://lyft.com/test-team"); @@ -3788,7 +3788,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDirectSou auto filter_chain = findFilterChain(1234, "/uds_1", "", "tls", {}, "/uds_2", 111, "/uds_3"); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto uri = ssl_socket->ssl()->uriSanLocalCertificate(); @@ -3798,7 +3798,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDirectSou filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111, "127.0.0.1"); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); uri = ssl_socket->ssl()->uriSanLocalCertificate(); EXPECT_EQ(uri[0], "spiffe://lyft.com/test-team"); @@ -3807,7 +3807,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDirectSou filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111, "192.168.0.1"); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -3817,7 +3817,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDirectSou filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111, "192.168.1.1"); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 2); @@ -3887,7 +3887,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam auto filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto uri = ssl_socket->ssl()->uriSanLocalCertificate(); @@ -3898,7 +3898,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam findFilterChain(1234, "127.0.0.1", "server1.example.com", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -3909,7 +3909,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam findFilterChain(1234, "127.0.0.1", "server2.example.com", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 2); @@ -3920,7 +3920,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam findFilterChain(1234, "127.0.0.1", "www.wildcard.com", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 2); @@ -3991,7 +3991,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithTransport filter_chain = findFilterChain(1234, "127.0.0.1", "", "tls", {}, "127.0.0.1", 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -4066,7 +4066,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithApplicati 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); @@ -4133,7 +4133,7 @@ TEST_P(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithMultipleR 111); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); + auto transport_socket = filter_chain->transportSocketFactory().createDownstreamTransportSocket(); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->ssl()->dnsSansLocalCertificate(); diff --git a/test/test_common/network_utility.cc b/test/test_common/network_utility.cc index dec1dceca3ee5..d4bc42064a3e5 100644 --- a/test/test_common/network_utility.cc +++ b/test/test_common/network_utility.cc @@ -184,17 +184,21 @@ bindFreeLoopbackPort(Address::IpVersion version, Socket::Type type, bool reuse_p TransportSocketPtr createRawBufferSocket() { return std::make_unique(); } -TransportSocketFactoryPtr createRawBufferSocketFactory() { +UpstreamTransportSocketFactoryPtr createRawBufferSocketFactory() { + return std::make_unique(); +} + +DownstreamTransportSocketFactoryPtr createRawBufferDownstreamSocketFactory() { return std::make_unique(); } const Network::FilterChainSharedPtr -createEmptyFilterChain(TransportSocketFactoryPtr&& transport_socket_factory) { +createEmptyFilterChain(DownstreamTransportSocketFactoryPtr&& transport_socket_factory) { return std::make_shared(std::move(transport_socket_factory)); } const Network::FilterChainSharedPtr createEmptyFilterChainWithRawBufferSockets() { - return createEmptyFilterChain(createRawBufferSocketFactory()); + return createEmptyFilterChain(createRawBufferDownstreamSocketFactory()); } namespace { diff --git a/test/test_common/network_utility.h b/test/test_common/network_utility.h index 58e40374b0316..a632aa7f7a7d9 100644 --- a/test/test_common/network_utility.h +++ b/test/test_common/network_utility.h @@ -138,10 +138,15 @@ TransportSocketPtr createRawBufferSocket(); /** * Create a transport socket factory for testing purposes. - * @return TransportSocketFactoryPtr the transport socket factory to use with a cluster or a - * listener. + * @return TransportSocketFactoryPtr the transport socket factory to use with a cluster */ -TransportSocketFactoryPtr createRawBufferSocketFactory(); +UpstreamTransportSocketFactoryPtr createRawBufferSocketFactory(); + +/** + * Create a transport socket factory for testing purposes. + * @return TransportSocketFactoryPtr the transport socket factory to use with a listener. + */ +DownstreamTransportSocketFactoryPtr createRawBufferDownstreamSocketFactory(); /** * Implementation of Network::FilterChain with empty filter chain, but pluggable transport socket @@ -149,11 +154,11 @@ TransportSocketFactoryPtr createRawBufferSocketFactory(); */ class EmptyFilterChain : public FilterChain { public: - EmptyFilterChain(TransportSocketFactoryPtr&& transport_socket_factory) + EmptyFilterChain(DownstreamTransportSocketFactoryPtr&& transport_socket_factory) : transport_socket_factory_(std::move(transport_socket_factory)) {} // Network::FilterChain - const TransportSocketFactory& transportSocketFactory() const override { + const DownstreamTransportSocketFactory& transportSocketFactory() const override { return *transport_socket_factory_; } @@ -168,7 +173,7 @@ class EmptyFilterChain : public FilterChain { absl::string_view name() const override { return "EmptyFilterChain"; } private: - const TransportSocketFactoryPtr transport_socket_factory_; + const DownstreamTransportSocketFactoryPtr transport_socket_factory_; const std::vector empty_network_filter_factory_{}; }; @@ -178,7 +183,7 @@ class EmptyFilterChain : public FilterChain { * @return const FilterChainSharedPtr filter chain. */ const FilterChainSharedPtr -createEmptyFilterChain(TransportSocketFactoryPtr&& transport_socket_factory); +createEmptyFilterChain(DownstreamTransportSocketFactoryPtr&& transport_socket_factory); /** * Create an empty filter chain creating raw buffer sockets for testing purposes.