diff --git a/envoy/network/BUILD b/envoy/network/BUILD index b24c61502c9a..6fc727974cf2 100644 --- a/envoy/network/BUILD +++ b/envoy/network/BUILD @@ -157,6 +157,7 @@ envoy_cc_library( ":proxy_protocol_options_lib", "//envoy/buffer:buffer_interface", "//envoy/ssl:connection_interface", + "//envoy/stream_info:filter_state_interface", ], ) diff --git a/envoy/network/transport_socket.h b/envoy/network/transport_socket.h index 68bf5e2448e7..f4b495f3f8ed 100644 --- a/envoy/network/transport_socket.h +++ b/envoy/network/transport_socket.h @@ -8,6 +8,7 @@ #include "envoy/network/post_io_action.h" #include "envoy/network/proxy_protocol.h" #include "envoy/ssl/connection.h" +#include "envoy/stream_info/filter_state.h" #include "absl/types/optional.h" @@ -220,13 +221,9 @@ class TransportSocketOptions { virtual absl::optional proxyProtocolOptions() const PURE; /** - * @param key supplies a vector of bytes to which the option should append hash key data that will - * be used to separate connections based on the option. Any data already in the key vector - * must not be modified. - * @param factory supplies the factor which will be used for creating the transport socket. + * @return filter state from the downstream request or connection. */ - virtual void hashKey(std::vector& key, - const Network::TransportSocketFactory& factory) const PURE; + virtual const StreamInfo::FilterStateSharedPtr& filterState() const PURE; }; using TransportSocketOptionsConstSharedPtr = std::shared_ptr; @@ -250,16 +247,20 @@ class TransportSocketFactory { virtual TransportSocketPtr createTransportSocket(TransportSocketOptionsConstSharedPtr options) const PURE; - /** - * @return bool whether the transport socket will use proxy protocol options. - */ - virtual bool usesProxyProtocolOptions() const PURE; - /** * Returns true if the transport socket created by this factory supports some form of ALPN * negotiation. */ virtual bool supportsAlpn() const { return false; } + + /** + * @param key supplies a vector of bytes to which the option should append hash key data that will + * be used to separate connections based on the option. Any data already in the key vector + * must not be modified. + * @param options supplies the transport socket options. + */ + virtual void hashKey(std::vector& key, + TransportSocketOptionsConstSharedPtr options) const PURE; }; using TransportSocketFactoryPtr = std::unique_ptr; diff --git a/source/common/network/BUILD b/source/common/network/BUILD index 07e7c3783814..ba05c1bfa4ad 100644 --- a/source/common/network/BUILD +++ b/source/common/network/BUILD @@ -276,6 +276,7 @@ envoy_cc_library( "//source/common/buffer:buffer_lib", "//source/common/common:empty_string", "//source/common/http:headers_lib", + "//source/common/network:transport_socket_options_lib", ], ) diff --git a/source/common/network/raw_buffer_socket.cc b/source/common/network/raw_buffer_socket.cc index 447da05eeb53..a35dd49e7093 100644 --- a/source/common/network/raw_buffer_socket.cc +++ b/source/common/network/raw_buffer_socket.cc @@ -92,5 +92,6 @@ RawBufferSocketFactory::createTransportSocket(TransportSocketOptionsConstSharedP } bool RawBufferSocketFactory::implementsSecureTransport() const { return false; } + } // namespace Network } // namespace Envoy diff --git a/source/common/network/raw_buffer_socket.h b/source/common/network/raw_buffer_socket.h index b2f94be856bf..46ecccc5c4c6 100644 --- a/source/common/network/raw_buffer_socket.h +++ b/source/common/network/raw_buffer_socket.h @@ -5,6 +5,7 @@ #include "envoy/network/transport_socket.h" #include "source/common/common/logger.h" +#include "source/common/network/transport_socket_options_impl.h" namespace Envoy { namespace Network { @@ -29,13 +30,12 @@ class RawBufferSocket : public TransportSocket, protected Logger::Loggable& key, - const Network::TransportSocketFactory& factory) { - const auto& server_name_overide = options.serverNameOverride(); + +void CommonTransportSocketFactory::hashKey(std::vector& key, + TransportSocketOptionsConstSharedPtr options) const { + if (!options) { + return; + } + const auto& server_name_overide = options->serverNameOverride(); if (server_name_overide.has_value()) { pushScalarToByteVector(StringUtil::CaseInsensitiveHash()(server_name_overide.value()), key); } - const auto& verify_san_list = options.verifySubjectAltNameListOverride(); + const auto& verify_san_list = options->verifySubjectAltNameListOverride(); for (const auto& san : verify_san_list) { pushScalarToByteVector(StringUtil::CaseInsensitiveHash()(san), key); } - const auto& alpn_list = options.applicationProtocolListOverride(); + const auto& alpn_list = options->applicationProtocolListOverride(); for (const auto& protocol : alpn_list) { pushScalarToByteVector(StringUtil::CaseInsensitiveHash()(protocol), key); } - const auto& alpn_fallback = options.applicationProtocolFallback(); + const auto& alpn_fallback = options->applicationProtocolFallback(); for (const auto& protocol : alpn_fallback) { pushScalarToByteVector(StringUtil::CaseInsensitiveHash()(protocol), key); } - - // Proxy protocol options should only be included in the hash if the upstream - // socket intends to use them. - const auto& proxy_protocol_options = options.proxyProtocolOptions(); - if (proxy_protocol_options.has_value() && factory.usesProxyProtocolOptions()) { - pushScalarToByteVector( - StringUtil::CaseInsensitiveHash()(proxy_protocol_options.value().asStringForHash()), key); - } } -} // namespace -void AlpnDecoratingTransportSocketOptions::hashKey( - std::vector& key, const Network::TransportSocketFactory& factory) const { - commonHashKey(*this, key, factory); -} - -void TransportSocketOptionsImpl::hashKey(std::vector& key, - const Network::TransportSocketFactory& factory) const { - commonHashKey(*this, key, factory); -} - -TransportSocketOptionsConstSharedPtr -TransportSocketOptionsUtility::fromFilterState(const StreamInfo::FilterState& filter_state) { +TransportSocketOptionsConstSharedPtr TransportSocketOptionsUtility::fromFilterState( + const StreamInfo::FilterStateSharedPtr& filter_state) { + if (!filter_state) { + return nullptr; + } absl::string_view server_name; std::vector application_protocols; std::vector subject_alt_names; std::vector alpn_fallback; absl::optional proxy_protocol_options; - bool needs_transport_socket_options = false; - if (auto typed_data = filter_state.getDataReadOnly(UpstreamServerName::key()); + if (auto typed_data = + filter_state->getDataReadOnly(UpstreamServerName::key()); typed_data != nullptr) { server_name = typed_data->value(); - needs_transport_socket_options = true; } - if (auto typed_data = filter_state.getDataReadOnly( + if (auto typed_data = filter_state->getDataReadOnly( Network::ApplicationProtocols::key()); typed_data != nullptr) { application_protocols = typed_data->value(); - needs_transport_socket_options = true; } if (auto typed_data = - filter_state.getDataReadOnly(UpstreamSubjectAltNames::key()); + filter_state->getDataReadOnly(UpstreamSubjectAltNames::key()); typed_data != nullptr) { subject_alt_names = typed_data->value(); - needs_transport_socket_options = true; } if (auto typed_data = - filter_state.getDataReadOnly(ProxyProtocolFilterState::key()); + filter_state->getDataReadOnly(ProxyProtocolFilterState::key()); typed_data != nullptr) { proxy_protocol_options.emplace(typed_data->value()); - needs_transport_socket_options = true; } - if (needs_transport_socket_options) { - return std::make_shared( - server_name, std::move(subject_alt_names), std::move(application_protocols), - std::move(alpn_fallback), proxy_protocol_options); - } else { - return nullptr; - } + return std::make_shared( + server_name, std::move(subject_alt_names), std::move(application_protocols), + std::move(alpn_fallback), proxy_protocol_options, filter_state); } } // namespace Network diff --git a/source/common/network/transport_socket_options_impl.h b/source/common/network/transport_socket_options_impl.h index 6512a29c03c5..34029dcbe7e4 100644 --- a/source/common/network/transport_socket_options_impl.h +++ b/source/common/network/transport_socket_options_impl.h @@ -29,8 +29,9 @@ class AlpnDecoratingTransportSocketOptions : public TransportSocketOptions { absl::optional proxyProtocolOptions() const override { return inner_options_->proxyProtocolOptions(); } - void hashKey(std::vector& key, - const Network::TransportSocketFactory& factory) const override; + const StreamInfo::FilterStateSharedPtr& filterState() const override { + return inner_options_->filterState(); + } private: const std::vector alpn_fallback_; @@ -43,13 +44,14 @@ class TransportSocketOptionsImpl : public TransportSocketOptions { absl::string_view override_server_name = "", std::vector&& override_verify_san_list = {}, std::vector&& override_alpn = {}, std::vector&& fallback_alpn = {}, - absl::optional proxy_proto_options = absl::nullopt) + absl::optional proxy_proto_options = absl::nullopt, + const StreamInfo::FilterStateSharedPtr filter_state = nullptr) : override_server_name_(override_server_name.empty() ? absl::nullopt : absl::optional(override_server_name)), override_verify_san_list_{std::move(override_verify_san_list)}, override_alpn_list_{std::move(override_alpn)}, alpn_fallback_{std::move(fallback_alpn)}, - proxy_protocol_options_(proxy_proto_options) {} + proxy_protocol_options_(proxy_proto_options), filter_state_(filter_state) {} // Network::TransportSocketOptions const absl::optional& serverNameOverride() const override { @@ -67,8 +69,7 @@ class TransportSocketOptionsImpl : public TransportSocketOptions { absl::optional proxyProtocolOptions() const override { return proxy_protocol_options_; } - void hashKey(std::vector& key, - const Network::TransportSocketFactory& factory) const override; + const StreamInfo::FilterStateSharedPtr& filterState() const override { return filter_state_; } private: const absl::optional override_server_name_; @@ -76,6 +77,7 @@ class TransportSocketOptionsImpl : public TransportSocketOptions { const std::vector override_alpn_list_; const std::vector alpn_fallback_; const absl::optional proxy_protocol_options_; + const StreamInfo::FilterStateSharedPtr filter_state_; }; class TransportSocketOptionsUtility { @@ -87,7 +89,16 @@ class TransportSocketOptionsUtility { * nullptr if nothing is in the filter state. */ static TransportSocketOptionsConstSharedPtr - fromFilterState(const StreamInfo::FilterState& stream_info); + fromFilterState(const StreamInfo::FilterStateSharedPtr& stream_info); +}; + +class CommonTransportSocketFactory : public TransportSocketFactory { +public: + /** + * Compute the generic hash key from the transport socket options. + */ + void hashKey(std::vector& key, + TransportSocketOptionsConstSharedPtr options) const override; }; } // namespace Network diff --git a/source/common/quic/BUILD b/source/common/quic/BUILD index 121e1ae536e6..99718f63d8d9 100644 --- a/source/common/quic/BUILD +++ b/source/common/quic/BUILD @@ -410,6 +410,7 @@ envoy_cc_library( "//envoy/server:transport_socket_config_interface", "//envoy/ssl:context_config_interface", "//source/common/common:assert_lib", + "//source/common/network:transport_socket_options_lib", "//source/extensions/transport_sockets/tls:context_config_lib", "//source/extensions/transport_sockets/tls:ssl_socket_lib", "@envoy_api//envoy/extensions/transport_sockets/quic/v3:pkg_cc_proto", diff --git a/source/common/quic/quic_transport_socket_factory.h b/source/common/quic/quic_transport_socket_factory.h index 778e08567aed..ffc334480df5 100644 --- a/source/common/quic/quic_transport_socket_factory.h +++ b/source/common/quic/quic_transport_socket_factory.h @@ -6,6 +6,7 @@ #include "envoy/ssl/context_config.h" #include "source/common/common/assert.h" +#include "source/common/network/transport_socket_options_impl.h" #include "source/extensions/transport_sockets/tls/ssl_socket.h" namespace Envoy { @@ -34,7 +35,7 @@ QuicTransportSocketFactoryStats generateStats(Stats::Scope& store, const std::st // socket for QUIC in current implementation. This factory doesn't provides a // transport socket, instead, its derived class provides TLS context config for // server and client. -class QuicTransportSocketFactoryBase : public Network::TransportSocketFactory, +class QuicTransportSocketFactoryBase : public Network::CommonTransportSocketFactory, protected Logger::Loggable { public: QuicTransportSocketFactoryBase(Stats::Scope& store, const std::string& perspective) @@ -49,7 +50,6 @@ class QuicTransportSocketFactoryBase : public Network::TransportSocketFactory, PANIC("not implemented"); } bool implementsSecureTransport() const override { return true; } - bool usesProxyProtocolOptions() const override { return false; } bool supportsAlpn() const override { return true; } protected: diff --git a/source/common/router/router.cc b/source/common/router/router.cc index c77c0b6df08a..0ea7d374a15b 100644 --- a/source/common/router/router.cc +++ b/source/common/router/router.cc @@ -562,7 +562,7 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, } transport_socket_options_ = Network::TransportSocketOptionsUtility::fromFilterState( - *callbacks_->streamInfo().filterState()); + callbacks_->streamInfo().filterState()); if (auto downstream_connection = downstreamConnection(); downstream_connection != nullptr) { if (auto typed_state = downstream_connection->streamInfo() diff --git a/source/common/tcp_proxy/tcp_proxy.cc b/source/common/tcp_proxy/tcp_proxy.cc index ce7a0dd45b9a..04335ecadf1b 100644 --- a/source/common/tcp_proxy/tcp_proxy.cc +++ b/source/common/tcp_proxy/tcp_proxy.cc @@ -377,7 +377,7 @@ Network::FilterStatus Filter::initializeUpstreamConnection() { StreamInfo::FilterState::LifeSpan::Connection); } transport_socket_options_ = Network::TransportSocketOptionsUtility::fromFilterState( - downstream_connection->streamInfo().filterState()); + read_callbacks_->connection().streamInfo().filterState()); if (auto typed_state = downstream_connection->streamInfo() .filterState() diff --git a/source/common/upstream/cluster_manager_impl.cc b/source/common/upstream/cluster_manager_impl.cc index ae6bf63b68ab..790fc66f6f72 100644 --- a/source/common/upstream/cluster_manager_impl.cc +++ b/source/common/upstream/cluster_manager_impl.cc @@ -1651,7 +1651,7 @@ ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::httpConnPoolImp bool have_transport_socket_options = false; if (context && context->upstreamTransportSocketOptions()) { - context->upstreamTransportSocketOptions()->hashKey(hash_key, host->transportSocketFactory()); + host->transportSocketFactory().hashKey(hash_key, context->upstreamTransportSocketOptions()); have_transport_socket_options = true; } @@ -1752,7 +1752,7 @@ ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::tcpConnPoolImpl bool have_transport_socket_options = false; if (context != nullptr && context->upstreamTransportSocketOptions() != nullptr) { have_transport_socket_options = true; - context->upstreamTransportSocketOptions()->hashKey(hash_key, host->transportSocketFactory()); + host->transportSocketFactory().hashKey(hash_key, context->upstreamTransportSocketOptions()); } TcpConnPoolsContainer& container = parent_.host_tcp_conn_pool_map_[host]; diff --git a/source/extensions/transport_sockets/alts/BUILD b/source/extensions/transport_sockets/alts/BUILD index 2086ef77546e..1b909e1fee8f 100644 --- a/source/extensions/transport_sockets/alts/BUILD +++ b/source/extensions/transport_sockets/alts/BUILD @@ -96,6 +96,7 @@ envoy_cc_library( "//source/common/common:empty_string", "//source/common/common:enum_to_int", "//source/common/network:raw_buffer_socket_lib", + "//source/common/network:transport_socket_options_lib", "//source/common/protobuf:utility_lib", ], ) diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h index 913f587d9b3f..e7f48d4bb83c 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.h +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -5,6 +5,7 @@ #include "source/common/buffer/buffer_impl.h" #include "source/common/buffer/watermark_buffer.h" #include "source/common/network/raw_buffer_socket.h" +#include "source/common/network/transport_socket_options_impl.h" #include "source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h" #include "source/extensions/transport_sockets/alts/tsi_frame_protector.h" #include "source/extensions/transport_sockets/alts/tsi_handshaker.h" @@ -129,12 +130,12 @@ class TsiSocket : public Network::TransportSocket, /** * An implementation of Network::TransportSocketFactory for TsiSocket */ -class TsiSocketFactory : public Network::TransportSocketFactory { +class TsiSocketFactory : public Network::CommonTransportSocketFactory { public: TsiSocketFactory(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator); bool implementsSecureTransport() const override; - bool usesProxyProtocolOptions() const override { return false; } + Network::TransportSocketPtr createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; diff --git a/source/extensions/transport_sockets/common/BUILD b/source/extensions/transport_sockets/common/BUILD index d333c0993f44..1dbee281080e 100644 --- a/source/extensions/transport_sockets/common/BUILD +++ b/source/extensions/transport_sockets/common/BUILD @@ -16,5 +16,6 @@ envoy_cc_library( "//envoy/network:connection_interface", "//envoy/network:transport_socket_interface", "//source/common/buffer:buffer_lib", + "//source/common/network:transport_socket_options_lib", ], ) diff --git a/source/extensions/transport_sockets/common/passthrough.h b/source/extensions/transport_sockets/common/passthrough.h index cb3ecc529ab4..3ae8de97a4eb 100644 --- a/source/extensions/transport_sockets/common/passthrough.h +++ b/source/extensions/transport_sockets/common/passthrough.h @@ -4,12 +4,13 @@ #include "envoy/network/transport_socket.h" #include "source/common/buffer/buffer_impl.h" +#include "source/common/network/transport_socket_options_impl.h" namespace Envoy { namespace Extensions { namespace TransportSockets { -class PassthroughFactory : public Network::TransportSocketFactory { +class PassthroughFactory : public Network::CommonTransportSocketFactory { public: PassthroughFactory(Network::TransportSocketFactoryPtr&& transport_socket_factory) : transport_socket_factory_(std::move(transport_socket_factory)) { @@ -20,9 +21,6 @@ class PassthroughFactory : public Network::TransportSocketFactory { return transport_socket_factory_->implementsSecureTransport(); } bool supportsAlpn() const override { return transport_socket_factory_->supportsAlpn(); } - bool usesProxyProtocolOptions() const override { - return transport_socket_factory_->usesProxyProtocolOptions(); - } protected: // The wrapped factory. diff --git a/source/extensions/transport_sockets/proxy_protocol/BUILD b/source/extensions/transport_sockets/proxy_protocol/BUILD index c84e1932ebc6..2917ee384df4 100644 --- a/source/extensions/transport_sockets/proxy_protocol/BUILD +++ b/source/extensions/transport_sockets/proxy_protocol/BUILD @@ -31,6 +31,8 @@ envoy_cc_library( "//envoy/network:connection_interface", "//envoy/network:transport_socket_interface", "//source/common/buffer:buffer_lib", + "//source/common/common:scalar_to_byte_vector_lib", + "//source/common/common:utility_lib", "//source/common/network:address_lib", "//source/extensions/common/proxy_protocol:proxy_protocol_header_lib", "//source/extensions/transport_sockets/common:passthrough_lib", diff --git a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc index d14482e24c2c..7f3843a8f4aa 100644 --- a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc +++ b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc @@ -6,6 +6,8 @@ #include "envoy/network/transport_socket.h" #include "source/common/buffer/buffer_impl.h" +#include "source/common/common/scalar_to_byte_vector.h" +#include "source/common/common/utility.h" #include "source/common/network/address_impl.h" #include "source/extensions/common/proxy_protocol/proxy_protocol_header.h" @@ -119,6 +121,20 @@ Network::TransportSocketPtr UpstreamProxyProtocolSocketFactory::createTransportS config_.version()); } +void UpstreamProxyProtocolSocketFactory::hashKey( + std::vector& key, Network::TransportSocketOptionsConstSharedPtr options) const { + PassthroughFactory::hashKey(key, options); + // Proxy protocol options should only be included in the hash if the upstream + // socket intends to use them. + if (options) { + const auto& proxy_protocol_options = options->proxyProtocolOptions(); + if (proxy_protocol_options.has_value()) { + pushScalarToByteVector( + StringUtil::CaseInsensitiveHash()(proxy_protocol_options.value().asStringForHash()), key); + } + } +} + } // namespace ProxyProtocol } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h index 43f283252912..eb35d93d10fd 100644 --- a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h +++ b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h @@ -47,7 +47,8 @@ class UpstreamProxyProtocolSocketFactory : public PassthroughFactory { // Network::TransportSocketFactory Network::TransportSocketPtr createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; - bool usesProxyProtocolOptions() const override { return true; } + void hashKey(std::vector& key, + Network::TransportSocketOptionsConstSharedPtr options) const override; private: ProxyProtocolConfig config_; diff --git a/source/extensions/transport_sockets/starttls/BUILD b/source/extensions/transport_sockets/starttls/BUILD index 31fcfef60a36..197a0f584be9 100644 --- a/source/extensions/transport_sockets/starttls/BUILD +++ b/source/extensions/transport_sockets/starttls/BUILD @@ -44,6 +44,7 @@ envoy_cc_library( "//source/common/common:empty_string", "//source/common/common:minimal_logger_lib", "//source/common/common:thread_annotations", + "//source/common/network:transport_socket_options_lib", "@envoy_api//envoy/extensions/transport_sockets/starttls/v3:pkg_cc_proto", ], ) diff --git a/source/extensions/transport_sockets/starttls/starttls_socket.h b/source/extensions/transport_sockets/starttls/starttls_socket.h index 4829049f83a6..1906049aa76d 100644 --- a/source/extensions/transport_sockets/starttls/starttls_socket.h +++ b/source/extensions/transport_sockets/starttls/starttls_socket.h @@ -8,6 +8,7 @@ #include "source/common/buffer/buffer_impl.h" #include "source/common/common/logger.h" +#include "source/common/network/transport_socket_options_impl.h" namespace Envoy { namespace Extensions { @@ -68,7 +69,7 @@ class StartTlsSocket : public Network::TransportSocket, Logger::Loggable { public: ~StartTlsSocketFactory() override = default; @@ -81,7 +82,6 @@ class StartTlsSocketFactory : public Network::TransportSocketFactory, Network::TransportSocketPtr createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; bool implementsSecureTransport() const override { return false; } - bool usesProxyProtocolOptions() const override { return false; } private: Network::TransportSocketFactoryPtr raw_socket_factory_; diff --git a/source/extensions/transport_sockets/tls/BUILD b/source/extensions/transport_sockets/tls/BUILD index f7912127a864..63258c4c4812 100644 --- a/source/extensions/transport_sockets/tls/BUILD +++ b/source/extensions/transport_sockets/tls/BUILD @@ -107,6 +107,7 @@ envoy_cc_library( "//source/common/common:minimal_logger_lib", "//source/common/common:thread_annotations", "//source/common/http:headers_lib", + "//source/common/network:transport_socket_options_lib", ], ) diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index d2986e19bbff..6a5fe1aaa512 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -14,6 +14,7 @@ #include "envoy/stats/stats_macros.h" #include "source/common/common/logger.h" +#include "source/common/network/transport_socket_options_impl.h" #include "source/extensions/transport_sockets/tls/context_impl.h" #include "source/extensions/transport_sockets/tls/ssl_handshaker.h" #include "source/extensions/transport_sockets/tls/utility.h" @@ -97,7 +98,7 @@ class SslSocket : public Network::TransportSocket, SslHandshakerImplSharedPtr info_; }; -class ClientSslSocketFactory : public Network::TransportSocketFactory, +class ClientSslSocketFactory : public Network::CommonTransportSocketFactory, public Secret::SecretCallbacks, Logger::Loggable { public: @@ -109,7 +110,6 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory, Network::TransportSocketPtr createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; bool implementsSecureTransport() const override; - bool usesProxyProtocolOptions() const override { return false; } bool supportsAlpn() const override { return true; } // Secret::SecretCallbacks @@ -128,7 +128,7 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory, Envoy::Ssl::ClientContextSharedPtr ssl_ctx_ ABSL_GUARDED_BY(ssl_ctx_mu_); }; -class ServerSslSocketFactory : public Network::TransportSocketFactory, +class ServerSslSocketFactory : public Network::CommonTransportSocketFactory, public Secret::SecretCallbacks, Logger::Loggable { public: @@ -141,7 +141,6 @@ class ServerSslSocketFactory : public Network::TransportSocketFactory, Network::TransportSocketPtr createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options) const override; bool implementsSecureTransport() const override; - bool usesProxyProtocolOptions() const override { return false; } // Secret::SecretCallbacks void onAddOrUpdateSecret() override; diff --git a/test/common/http/conn_pool_grid_test.cc b/test/common/http/conn_pool_grid_test.cc index 790981c42978..945279c94fa8 100644 --- a/test/common/http/conn_pool_grid_test.cc +++ b/test/common/http/conn_pool_grid_test.cc @@ -792,7 +792,6 @@ TEST_F(ConnectivityGridTest, RealGrid) { auto factory = std::make_unique(std::move(config), factory_context); factory->initialize(); - ASSERT_FALSE(factory->usesProxyProtocolOptions()); auto& matcher = static_cast(*cluster_->transport_socket_matcher_); EXPECT_CALL(matcher, resolve(_)) @@ -837,7 +836,6 @@ TEST_F(ConnectivityGridTest, ConnectionCloseDuringCreation) { auto factory = std::make_unique(std::move(config), factory_context); factory->initialize(); - ASSERT_FALSE(factory->usesProxyProtocolOptions()); auto& matcher = static_cast(*cluster_->transport_socket_matcher_); EXPECT_CALL(matcher, resolve(_)) @@ -909,7 +907,6 @@ TEST_F(ConnectivityGridTest, ConnectionCloseDuringAysnConnect) { auto factory = std::make_unique(std::move(config), factory_context); factory->initialize(); - ASSERT_FALSE(factory->usesProxyProtocolOptions()); auto& matcher = static_cast(*cluster_->transport_socket_matcher_); EXPECT_CALL(matcher, resolve(_)) diff --git a/test/common/network/BUILD b/test/common/network/BUILD index 0afaf9f64096..39ca99aeeebd 100644 --- a/test/common/network/BUILD +++ b/test/common/network/BUILD @@ -188,6 +188,7 @@ envoy_cc_test( srcs = ["raw_buffer_socket_test.cc"], deps = [ "//source/common/network:raw_buffer_socket_lib", + "//source/common/network:transport_socket_options_lib", "//test/test_common:network_utility_lib", ], ) diff --git a/test/common/network/raw_buffer_socket_test.cc b/test/common/network/raw_buffer_socket_test.cc index 4304277b1b2d..46456985483f 100644 --- a/test/common/network/raw_buffer_socket_test.cc +++ b/test/common/network/raw_buffer_socket_test.cc @@ -1,4 +1,5 @@ #include "source/common/network/raw_buffer_socket.h" +#include "source/common/network/transport_socket_options_impl.h" #include "test/test_common/network_utility.h" @@ -9,7 +10,13 @@ namespace Network { TEST(RawBufferSocketFactory, RawBufferSocketFactory) { TransportSocketFactoryPtr factory = Envoy::Network::Test::createRawBufferSocketFactory(); - EXPECT_FALSE(factory->usesProxyProtocolOptions()); + EXPECT_FALSE(factory->implementsSecureTransport()); + std::vector keys; + factory->hashKey(keys, nullptr); + EXPECT_EQ(keys.size(), 0); + auto options = std::make_shared("server"); + factory->hashKey(keys, options); + EXPECT_GT(keys.size(), 0); } } // namespace Network diff --git a/test/common/network/transport_socket_options_impl_test.cc b/test/common/network/transport_socket_options_impl_test.cc index 8bf0b4bd2f16..7044054d982f 100644 --- a/test/common/network/transport_socket_options_impl_test.cc +++ b/test/common/network/transport_socket_options_impl_test.cc @@ -15,32 +15,34 @@ namespace { class TransportSocketOptionsImplTest : public testing::Test { public: TransportSocketOptionsImplTest() - : filter_state_(StreamInfo::FilterState::LifeSpan::FilterChain) {} + : filter_state_(std::make_shared( + StreamInfo::FilterState::LifeSpan::FilterChain)) {} protected: - StreamInfo::FilterStateImpl filter_state_; + StreamInfo::FilterStateSharedPtr filter_state_; }; TEST_F(TransportSocketOptionsImplTest, Nullptr) { - EXPECT_EQ(nullptr, TransportSocketOptionsUtility::fromFilterState(filter_state_)); - filter_state_.setData( + EXPECT_EQ(nullptr, TransportSocketOptionsUtility::fromFilterState(nullptr)); + filter_state_->setData( "random_key_has_no_effect", std::make_unique("www.example.com"), StreamInfo::FilterState::StateType::ReadOnly, StreamInfo::FilterState::LifeSpan::FilterChain); - EXPECT_EQ(nullptr, TransportSocketOptionsUtility::fromFilterState(filter_state_)); + auto transport_socket_options = TransportSocketOptionsUtility::fromFilterState(filter_state_); + EXPECT_TRUE(transport_socket_options->filterState()->hasDataWithName("random_key_has_no_effect")); } TEST_F(TransportSocketOptionsImplTest, UpstreamServer) { - filter_state_.setData( + filter_state_->setData( UpstreamServerName::key(), std::make_unique("www.example.com"), StreamInfo::FilterState::StateType::ReadOnly, StreamInfo::FilterState::LifeSpan::FilterChain); - filter_state_.setData(ProxyProtocolFilterState::key(), - std::make_unique(Network::ProxyProtocolData{ - Network::Address::InstanceConstSharedPtr( - new Network::Address::Ipv4Instance("202.168.0.13", 52000)), - Network::Address::InstanceConstSharedPtr( - new Network::Address::Ipv4Instance("174.2.2.222", 80))}), - StreamInfo::FilterState::StateType::ReadOnly, - StreamInfo::FilterState::LifeSpan::FilterChain); + filter_state_->setData(ProxyProtocolFilterState::key(), + std::make_unique(Network::ProxyProtocolData{ + Network::Address::InstanceConstSharedPtr( + new Network::Address::Ipv4Instance("202.168.0.13", 52000)), + Network::Address::InstanceConstSharedPtr( + new Network::Address::Ipv4Instance("174.2.2.222", 80))}), + StreamInfo::FilterState::StateType::ReadOnly, + StreamInfo::FilterState::LifeSpan::FilterChain); auto transport_socket_options = TransportSocketOptionsUtility::fromFilterState(filter_state_); EXPECT_EQ(absl::make_optional("www.example.com"), transport_socket_options->serverNameOverride()); @@ -52,7 +54,7 @@ TEST_F(TransportSocketOptionsImplTest, UpstreamServer) { TEST_F(TransportSocketOptionsImplTest, ApplicationProtocols) { std::vector http_alpns{Http::Utility::AlpnNames::get().Http2, Http::Utility::AlpnNames::get().Http11}; - filter_state_.setData( + filter_state_->setData( ApplicationProtocols::key(), std::make_unique(http_alpns), StreamInfo::FilterState::StateType::ReadOnly, StreamInfo::FilterState::LifeSpan::FilterChain); auto transport_socket_options = TransportSocketOptionsUtility::fromFilterState(filter_state_); @@ -63,10 +65,10 @@ TEST_F(TransportSocketOptionsImplTest, ApplicationProtocols) { TEST_F(TransportSocketOptionsImplTest, Both) { std::vector http_alpns{Http::Utility::AlpnNames::get().Http2, Http::Utility::AlpnNames::get().Http11}; - filter_state_.setData( + filter_state_->setData( UpstreamServerName::key(), std::make_unique("www.example.com"), StreamInfo::FilterState::StateType::ReadOnly, StreamInfo::FilterState::LifeSpan::FilterChain); - filter_state_.setData( + filter_state_->setData( ApplicationProtocols::key(), std::make_unique(http_alpns), StreamInfo::FilterState::StateType::ReadOnly, StreamInfo::FilterState::LifeSpan::FilterChain); auto transport_socket_options = TransportSocketOptionsUtility::fromFilterState(filter_state_); diff --git a/test/common/upstream/transport_socket_matcher_test.cc b/test/common/upstream/transport_socket_matcher_test.cc index 4a3f8af0690c..b1ae65848262 100644 --- a/test/common/upstream/transport_socket_matcher_test.cc +++ b/test/common/upstream/transport_socket_matcher_test.cc @@ -31,10 +31,11 @@ namespace { class FakeTransportSocketFactory : public Network::TransportSocketFactory { public: MOCK_METHOD(bool, implementsSecureTransport, (), (const)); - MOCK_METHOD(bool, usesProxyProtocolOptions, (), (const)); MOCK_METHOD(bool, supportsAlpn, (), (const)); MOCK_METHOD(Network::TransportSocketPtr, createTransportSocket, (Network::TransportSocketOptionsConstSharedPtr), (const)); + MOCK_METHOD(void, hashKey, (std::vector&, Network::TransportSocketOptionsConstSharedPtr), + (const)); FakeTransportSocketFactory(std::string id, bool alpn) : supports_alpn_(alpn), id_(std::move(id)) { ON_CALL(*this, supportsAlpn).WillByDefault(Invoke([this]() { return supports_alpn_; })); } @@ -52,9 +53,10 @@ class FooTransportSocketFactory Logger::Loggable { public: MOCK_METHOD(bool, implementsSecureTransport, (), (const)); - MOCK_METHOD(bool, usesProxyProtocolOptions, (), (const)); MOCK_METHOD(Network::TransportSocketPtr, createTransportSocket, (Network::TransportSocketOptionsConstSharedPtr), (const)); + MOCK_METHOD(void, hashKey, (std::vector&, Network::TransportSocketOptionsConstSharedPtr), + (const)); Network::TransportSocketFactoryPtr createTransportSocketFactory(const Protobuf::Message& proto, diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc index cb847c65a010..8dbf6f4c270b 100644 --- a/test/extensions/transport_sockets/alts/tsi_socket_test.cc +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -917,8 +917,10 @@ TEST_F(TsiSocketFactoryTest, ImplementsSecureTransport) { EXPECT_TRUE(socket_factory_->implementsSecureTransport()); } -TEST_F(TsiSocketFactoryTest, UsesProxyProtocolOptions) { - EXPECT_FALSE(socket_factory_->usesProxyProtocolOptions()); +TEST_F(TsiSocketFactoryTest, HashKey) { + std::vector key; + socket_factory_->hashKey(key, nullptr); + EXPECT_EQ(0, key.size()); } } // namespace diff --git a/test/extensions/transport_sockets/common/passthrough_test.cc b/test/extensions/transport_sockets/common/passthrough_test.cc index 9172f54c93f5..7b055f6f1077 100644 --- a/test/extensions/transport_sockets/common/passthrough_test.cc +++ b/test/extensions/transport_sockets/common/passthrough_test.cc @@ -110,8 +110,9 @@ TEST(PassthroughFactoryTest, TestDelegation) { factory->supportsAlpn(); } { - EXPECT_CALL(*inner_factory, usesProxyProtocolOptions()); - factory->usesProxyProtocolOptions(); + std::vector key; + EXPECT_CALL(*inner_factory, hashKey(_, _)); + factory->hashKey(key, nullptr); } } diff --git a/test/extensions/transport_sockets/starttls/starttls_socket_test.cc b/test/extensions/transport_sockets/starttls/starttls_socket_test.cc index 2a817a044d02..768cb13560e2 100644 --- a/test/extensions/transport_sockets/starttls/starttls_socket_test.cc +++ b/test/extensions/transport_sockets/starttls/starttls_socket_test.cc @@ -120,7 +120,9 @@ TEST(StartTls, BasicFactoryTest) { Network::TransportSocketFactoryPtr(raw_buffer_factory), Network::TransportSocketFactoryPtr(ssl_factory)); ASSERT_FALSE(factory->implementsSecureTransport()); - ASSERT_FALSE(factory->usesProxyProtocolOptions()); + std::vector key; + factory->hashKey(key, nullptr); + EXPECT_EQ(0, key.size()); } } // namespace StartTls diff --git a/test/extensions/transport_sockets/tls/ssl_socket_test.cc b/test/extensions/transport_sockets/tls/ssl_socket_test.cc index 06b61503fe81..af185f838c38 100644 --- a/test/extensions/transport_sockets/tls/ssl_socket_test.cc +++ b/test/extensions/transport_sockets/tls/ssl_socket_test.cc @@ -679,7 +679,6 @@ void testUtilV2(const TestUtilOptionsV2& options) { ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, server_stats_store, server_names); - EXPECT_FALSE(server_ssl_socket_factory.usesProxyProtocolOptions()); Event::DispatcherPtr dispatcher(server_api->allocateDispatcher("test_thread")); auto socket = std::make_shared( diff --git a/test/mocks/network/transport_socket.h b/test/mocks/network/transport_socket.h index 3716b783d3a0..4ad6cb3bfa61 100644 --- a/test/mocks/network/transport_socket.h +++ b/test/mocks/network/transport_socket.h @@ -39,10 +39,11 @@ class MockTransportSocketFactory : public TransportSocketFactory { ~MockTransportSocketFactory() override; MOCK_METHOD(bool, implementsSecureTransport, (), (const)); - MOCK_METHOD(bool, usesProxyProtocolOptions, (), (const)); MOCK_METHOD(bool, supportsAlpn, (), (const)); MOCK_METHOD(TransportSocketPtr, createTransportSocket, (TransportSocketOptionsConstSharedPtr), (const)); + MOCK_METHOD(void, hashKey, + (std::vector & key, TransportSocketOptionsConstSharedPtr options), (const)); }; } // namespace Network