diff --git a/envoy/ssl/connection.h b/envoy/ssl/connection.h index 501a99b80f280..42d8cf3f29554 100644 --- a/envoy/ssl/connection.h +++ b/envoy/ssl/connection.h @@ -142,6 +142,11 @@ class ConnectionInfo { * @return std::string the protocol negotiated via ALPN. **/ virtual const std::string& alpn() const PURE; + + /** + * @return std::string the SNI used to establish the connection. + **/ + virtual const std::string& sni() const PURE; }; using ConnectionInfoConstSharedPtr = std::shared_ptr; diff --git a/source/common/http/conn_pool_grid.cc b/source/common/http/conn_pool_grid.cc index 733bde76c57ad..0b0bda8598fa4 100644 --- a/source/common/http/conn_pool_grid.cc +++ b/source/common/http/conn_pool_grid.cc @@ -18,6 +18,17 @@ absl::string_view describePool(const ConnectionPool::Instance& pool) { static constexpr uint32_t kDefaultTimeoutMs = 300; +std::string getSni(const Network::TransportSocketOptionsConstSharedPtr& options, + Network::TransportSocketFactory& transport_socket_factory) { + if (options && options->serverNameOverride().has_value()) { + return options->serverNameOverride().value(); + } + auto* quic_socket_factory = + dynamic_cast(&transport_socket_factory); + ASSERT(quic_socket_factory != nullptr); + return quic_socket_factory->clientContextConfig().serverNameIndication(); +} + } // namespace ConnectivityGrid::WrapperCallbacks::WrapperCallbacks(ConnectivityGrid& grid, @@ -105,7 +116,7 @@ void ConnectivityGrid::WrapperCallbacks::deleteThis() { ConnectivityGrid::StreamCreationResult ConnectivityGrid::WrapperCallbacks::newStream() { ENVOY_LOG(trace, "{} pool attempting to create a new stream to host '{}'.", - describePool(**current_), grid_.host_->hostname()); + describePool(**current_), grid_.origin_.hostname_); auto attempt = std::make_unique(*this, current_); LinkedList::moveIntoList(std::move(attempt), connection_attempts_); if (!next_attempt_timer_->enabled()) { @@ -145,7 +156,7 @@ void ConnectivityGrid::WrapperCallbacks::onConnectionAttemptReady( void ConnectivityGrid::WrapperCallbacks::maybeMarkHttp3Broken() { if (http3_attempt_failed_ && tcp_attempt_succeeded_) { - ENVOY_LOG(trace, "Marking HTTP/3 broken for host '{}'.", grid_.host_->hostname()); + ENVOY_LOG(trace, "Marking HTTP/3 broken for host '{}'.", grid_.origin_.hostname_); grid_.markHttp3Broken(); } } @@ -206,15 +217,14 @@ ConnectivityGrid::ConnectivityGrid( priority_(priority), options_(options), transport_socket_options_(transport_socket_options), state_(state), next_attempt_duration_(std::chrono::milliseconds(kDefaultTimeoutMs)), time_source_(time_source), alternate_protocols_(alternate_protocols), - quic_stat_names_(quic_stat_names), scope_(scope), quic_info_(quic_info) { + quic_stat_names_(quic_stat_names), scope_(scope), + origin_("https", getSni(transport_socket_options, host_->transportSocketFactory()), + host_->address()->ip()->port()), + quic_info_(quic_info) { ASSERT(connectivity_options.protocols_.size() == 3); ASSERT(alternate_protocols); - // ProdClusterManagerFactory::allocateConnPool verifies the protocols are HTTP/1, HTTP/2 and - // HTTP/3. - AlternateProtocolsCache::Origin origin("https", host_->hostname(), - host_->address()->ip()->port()); std::chrono::milliseconds rtt = - std::chrono::duration_cast(alternate_protocols_->getSrtt(origin)); + std::chrono::duration_cast(alternate_protocols_->getSrtt(origin_)); if (rtt.count() != 0) { next_attempt_duration_ = std::chrono::milliseconds(rtt.count() * 2); } @@ -394,13 +404,11 @@ bool ConnectivityGrid::shouldAttemptHttp3() { return false; } uint32_t port = host_->address()->ip()->port(); - // TODO(RyanTheOptimist): Figure out how scheme gets plumbed in here. - AlternateProtocolsCache::Origin origin("https", host_->hostname(), port); OptRef> protocols = - alternate_protocols_->findAlternatives(origin); + alternate_protocols_->findAlternatives(origin_); if (!protocols.has_value()) { ENVOY_LOG(trace, "No alternate protocols available for host '{}', skipping HTTP/3.", - host_->hostname()); + origin_.hostname_); return false; } if (isHttp3Broken()) { @@ -412,7 +420,7 @@ bool ConnectivityGrid::shouldAttemptHttp3() { if (!protocol.hostname_.empty() || protocol.port_ != port) { ENVOY_LOG(trace, "Alternate protocol for host '{}' attempts to change host or port, skipping.", - host_->hostname()); + origin_.hostname_); continue; } @@ -424,20 +432,20 @@ bool ConnectivityGrid::shouldAttemptHttp3() { alt_svc, quic::CurrentSupportedVersions()); if (version != quic::ParsedQuicVersion::Unsupported()) { // TODO(RyanTheOptimist): Pass this version down to the HTTP/3 pool. - ENVOY_LOG(trace, "HTTP/3 advertised for host '{}'", host_->hostname()); + ENVOY_LOG(trace, "HTTP/3 advertised for host '{}'", origin_.hostname_); return true; } ENVOY_LOG(trace, "Alternate protocol for host '{}' has unsupported ALPN '{}', skipping.", - host_->hostname(), protocol.alpn_); + origin_.hostname_, protocol.alpn_); } - ENVOY_LOG(trace, "HTTP/3 is not available to host '{}', skipping.", host_->hostname()); + ENVOY_LOG(trace, "HTTP/3 is not available to host '{}', skipping.", origin_.hostname_); return false; } void ConnectivityGrid::onHandshakeComplete() { - ENVOY_LOG(trace, "Marking HTTP/3 confirmed for host '{}'.", host_->hostname()); + ENVOY_LOG(trace, "Marking HTTP/3 confirmed for host '{}'.", origin_.hostname_); markHttp3Confirmed(); } diff --git a/source/common/http/conn_pool_grid.h b/source/common/http/conn_pool_grid.h index 9f18595254a55..2f297737609f1 100644 --- a/source/common/http/conn_pool_grid.h +++ b/source/common/http/conn_pool_grid.h @@ -235,6 +235,10 @@ class ConnectivityGrid : public ConnectionPool::Instance, Quic::QuicStatNames& quic_stat_names_; Stats::Scope& scope_; + // The origin for this pool. + // Note the host name here is based off of the host name used for SNI, which + // may be from the cluster config, or the request headers for auto-sni. + AlternateProtocolsCache::Origin origin_; Http::PersistentQuicInfo& quic_info_; }; diff --git a/source/common/http/http3/conn_pool.cc b/source/common/http/http3/conn_pool.cc index c16a4e8ef1053..466a058cd24f0 100644 --- a/source/common/http/http3/conn_pool.cc +++ b/source/common/http/http3/conn_pool.cc @@ -22,6 +22,19 @@ uint32_t getMaxStreams(const Upstream::ClusterInfo& cluster) { max_concurrent_streams, 100); } +const Envoy::Ssl::ClientContextConfig& +getConfig(Network::TransportSocketFactory& transport_socket_factory) { + return dynamic_cast(transport_socket_factory) + .clientContextConfig(); +} + +std::string sni(const Network::TransportSocketOptionsConstSharedPtr& options, + Upstream::HostConstSharedPtr host) { + return options && options->serverNameOverride().has_value() + ? options->serverNameOverride().value() + : getConfig(host->transportSocketFactory()).serverNameIndication(); +} + } // namespace ActiveClient::ActiveClient(Envoy::Http::HttpConnPoolImplBase& parent, @@ -54,12 +67,6 @@ void ActiveClient::onMaxStreamsChanged(uint32_t num_streams) { } } -const Envoy::Ssl::ClientContextConfig& -getConfig(Network::TransportSocketFactory& transport_socket_factory) { - return dynamic_cast(transport_socket_factory) - .clientContextConfig(); -} - ConnectionPool::Cancellable* Http3ConnPoolImpl::newStream(Http::ResponseDecoder& response_decoder, ConnectionPool::Callbacks& callbacks, const Instance::StreamOptions& options) { @@ -78,7 +85,7 @@ Http3ConnPoolImpl::Http3ConnPoolImpl( : FixedHttpConnPoolImpl(host, priority, dispatcher, options, transport_socket_options, random_generator, state, client_fn, codec_fn, protocol), quic_info_(dynamic_cast(quic_info)), - server_id_(getConfig(host_->transportSocketFactory()).serverNameIndication(), + server_id_(sni(transport_socket_options, host), static_cast(host_->address()->ip()->port()), false), connect_callback_(connect_callback) {} diff --git a/source/extensions/filters/http/alternate_protocols_cache/filter.cc b/source/extensions/filters/http/alternate_protocols_cache/filter.cc index d9f09a15fbee4..49cf9f55cf87a 100644 --- a/source/extensions/filters/http/alternate_protocols_cache/filter.cc +++ b/source/extensions/filters/http/alternate_protocols_cache/filter.cc @@ -66,8 +66,14 @@ Http::FilterHeadersStatus Filter::encodeHeaders(Http::ResponseHeaderMap& headers // balanced across them. Upstream::HostDescriptionConstSharedPtr host = encoder_callbacks_->streamInfo().upstreamInfo()->upstreamHost(); + absl::string_view hostname = host->hostname(); + if (encoder_callbacks_->streamInfo().upstreamInfo()->upstreamSslConnection() && + !encoder_callbacks_->streamInfo().upstreamInfo()->upstreamSslConnection()->sni().empty()) { + // In the case the configured hostname and SNI differ, prefer SNI where + // available. + hostname = encoder_callbacks_->streamInfo().upstreamInfo()->upstreamSslConnection()->sni(); + } const uint32_t port = host->address()->ip()->port(); - const std::string& hostname = host->hostname(); Http::AlternateProtocolsCache::Origin origin(Http::Headers::get().SchemeValues.Https, hostname, port); cache_->setAlternatives(origin, protocols); diff --git a/source/extensions/transport_sockets/tls/connection_info_impl_base.cc b/source/extensions/transport_sockets/tls/connection_info_impl_base.cc index 3aa9974ff8b93..8c1ea7b697ac7 100644 --- a/source/extensions/transport_sockets/tls/connection_info_impl_base.cc +++ b/source/extensions/transport_sockets/tls/connection_info_impl_base.cc @@ -202,6 +202,16 @@ const std::string& ConnectionInfoImplBase::alpn() const { return alpn_; } +const std::string& ConnectionInfoImplBase::sni() const { + if (sni_.empty()) { + const char* proto = SSL_get_servername(ssl(), TLSEXT_NAMETYPE_host_name); + if (proto != nullptr) { + sni_ = std::string(proto); + } + } + return sni_; +} + const std::string& ConnectionInfoImplBase::serialNumberPeerCertificate() const { if (!cached_serial_number_peer_certificate_.empty()) { return cached_serial_number_peer_certificate_; diff --git a/source/extensions/transport_sockets/tls/connection_info_impl_base.h b/source/extensions/transport_sockets/tls/connection_info_impl_base.h index f5bfa73b0ee1d..8334ee174a2b6 100644 --- a/source/extensions/transport_sockets/tls/connection_info_impl_base.h +++ b/source/extensions/transport_sockets/tls/connection_info_impl_base.h @@ -39,6 +39,7 @@ class ConnectionInfoImplBase : public Ssl::ConnectionInfo { std::string ciphersuiteString() const override; const std::string& tlsVersion() const override; const std::string& alpn() const override; + const std::string& sni() const override; virtual SSL* ssl() const PURE; @@ -58,6 +59,7 @@ class ConnectionInfoImplBase : public Ssl::ConnectionInfo { mutable std::string cached_session_id_; mutable std::string cached_tls_version_; mutable std::string alpn_; + mutable std::string sni_; }; } // namespace Tls diff --git a/test/extensions/filters/http/alternate_protocols_cache/filter_integration_test.cc b/test/extensions/filters/http/alternate_protocols_cache/filter_integration_test.cc index f474d36810d71..427e403231ce0 100644 --- a/test/extensions/filters/http/alternate_protocols_cache/filter_integration_test.cc +++ b/test/extensions/filters/http/alternate_protocols_cache/filter_integration_test.cc @@ -169,7 +169,6 @@ TEST_P(FilterIntegrationTest, H3PostHandshakeFailoverToTcp) { ASSERT_TRUE(fake_upstream_connection_->close()); test_server_->waitForCounterEq("cluster.cluster_0.upstream_cx_destroy", 1); fake_upstream_connection_.reset(); - // Second request should go out over HTTP/3 because of the Alt-Svc information. auto response2 = codec_client_->makeHeaderOnlyRequest(request_headers); waitForNextUpstreamRequest(1); @@ -200,22 +199,14 @@ INSTANTIATE_TEST_SUITE_P(Protocols, FilterIntegrationTest, // an HTTP/2 or an HTTP/3 upstream (but not both). class MixedUpstreamIntegrationTest : public FilterIntegrationTest { protected: - void initialize() override { - // TODO(alyssawilk) there's no config guarantee that SNI and hostname - // match, but alt-svc rtt caching doesn't work unless they do. Fix. - config_helper_.addConfigModifier( - [&](envoy::config::bootstrap::v3::Bootstrap& bootstrap) -> void { - auto cluster = bootstrap.mutable_static_resources()->mutable_clusters(0); - auto locality_lb = cluster->mutable_load_assignment()->mutable_endpoints(0); - auto endpoint = locality_lb->mutable_lb_endpoints(0)->mutable_endpoint(); - endpoint->set_hostname("foo.lyft.com"); - }); - FilterIntegrationTest::initialize(); + MixedUpstreamIntegrationTest() { + TestEnvironment::writeStringToFileForTest("alt_svc_cache.txt", ""); + default_request_headers_.setHost("sni.lyft.com"); } void writeFile() { uint32_t port = fake_upstreams_[0]->localAddress()->ip()->port(); - std::string key = absl::StrCat("https://foo.lyft.com:", port); + std::string key = absl::StrCat("https://sni.lyft.com:", port); size_t seconds = std::chrono::duration_cast( timeSystem().monotonicTime().time_since_epoch()) @@ -250,10 +241,13 @@ int getSrtt(std::string alt_svc, TimeSource& time_source) { /*from_cache=*/false); return data.has_value() ? data.value().srtt.count() : 0; } + // Test auto-config with a pre-populated HTTP/3 alt-svc entry. The upstream request will // occur over HTTP/3. TEST_P(MixedUpstreamIntegrationTest, BasicRequestAutoWithHttp3) { - testRouterRequestAndResponseWithBody(0, 0, false); + initialize(); + codec_client_ = makeHttpConnection(makeClientConnection((lookupPort("http")))); + sendRequestAndWaitForResponse(default_request_headers_, 0, default_response_headers_, 0, 0); cleanupUpstreamAndDownstream(); std::string alt_svc; @@ -288,7 +282,9 @@ TEST_P(MixedUpstreamIntegrationTest, SimultaneousLargeRequestsAutoWithHttp3) { TEST_P(MixedUpstreamIntegrationTest, BasicRequestAutoWithHttp2) { // Only create an HTTP/2 upstream. use_http2_ = true; - testRouterRequestAndResponseWithBody(0, 0, false); + initialize(); + codec_client_ = makeHttpConnection(makeClientConnection((lookupPort("http")))); + sendRequestAndWaitForResponse(default_request_headers_, 0, default_response_headers_, 0, 0); } // Same as above, only multiple requests. diff --git a/test/extensions/filters/http/alternate_protocols_cache/filter_test.cc b/test/extensions/filters/http/alternate_protocols_cache/filter_test.cc index dd11f5717fe88..7e415c24d630d 100644 --- a/test/extensions/filters/http/alternate_protocols_cache/filter_test.cc +++ b/test/extensions/filters/http/alternate_protocols_cache/filter_test.cc @@ -105,7 +105,7 @@ TEST_F(FilterTest, ValidAltSvc) { std::shared_ptr hd = std::make_shared(); testing::NiceMock stream_info; - EXPECT_CALL(callbacks_, streamInfo()).WillOnce(ReturnRef(stream_info)); + EXPECT_CALL(callbacks_, streamInfo()).Times(2).WillOnce(ReturnRef(stream_info)); stream_info.upstreamInfo()->setUpstreamHost(hd); EXPECT_CALL(*hd, hostname()).WillOnce(ReturnRef(hostname)); EXPECT_CALL(*hd, address()).WillOnce(Return(address));