diff --git a/envoy/network/socket.h b/envoy/network/socket.h index 692670e651695..ed37195de7088 100644 --- a/envoy/network/socket.h +++ b/envoy/network/socket.h @@ -46,6 +46,8 @@ struct SocketOptionName { * Interfaces for providing a socket's various addresses. This is split into a getters interface * and a getters + setters interface. This is so that only the getters portion can be overridden * in certain cases. + * TODO(soulxu): Since there are more than address information inside the provider, this will be + * renamed as ConnectionInfoProvider. Ref https://github.com/envoyproxy/envoy/issues/17168 */ class SocketAddressProvider { public: @@ -73,6 +75,11 @@ class SocketAddressProvider { */ virtual const Address::InstanceConstSharedPtr& directRemoteAddress() const PURE; + /** + * @return SNI value for downstream host. + */ + virtual absl::string_view requestedServerName() const PURE; + /** * Dumps the state of the SocketAddressProvider to the given ostream. * @@ -109,6 +116,11 @@ class SocketAddressSetter : public SocketAddressProvider { * Set the remote address of the socket. */ virtual void setRemoteAddress(const Address::InstanceConstSharedPtr& remote_address) PURE; + + /** + * @param SNI value requested. + */ + virtual void setRequestedServerName(const absl::string_view requested_server_name) PURE; }; using SocketAddressSetterSharedPtr = std::shared_ptr; diff --git a/envoy/stream_info/stream_info.h b/envoy/stream_info/stream_info.h index 2b146b6e421a0..002f0ce8bdf63 100644 --- a/envoy/stream_info/stream_info.h +++ b/envoy/stream_info/stream_info.h @@ -529,16 +529,6 @@ class StreamInfo { virtual const FilterStateSharedPtr& upstreamFilterState() const PURE; virtual void setUpstreamFilterState(const FilterStateSharedPtr& filter_state) PURE; - /** - * @param SNI value requested. - */ - virtual void setRequestedServerName(const absl::string_view requested_server_name) PURE; - - /** - * @return SNI value for downstream host. - */ - virtual const std::string& requestedServerName() const PURE; - /** * @param failure_reason the upstream transport failure reason. */ diff --git a/source/common/formatter/substitution_formatter.cc b/source/common/formatter/substitution_formatter.cc index 76261d9f882ac..c8419263f367c 100644 --- a/source/common/formatter/substitution_formatter.cc +++ b/source/common/formatter/substitution_formatter.cc @@ -820,8 +820,8 @@ StreamInfoFormatter::StreamInfoFormatter(const std::string& field_name) { field_extractor_ = std::make_unique( [](const StreamInfo::StreamInfo& stream_info) { absl::optional result; - if (!stream_info.requestedServerName().empty()) { - result = stream_info.requestedServerName(); + if (!stream_info.downstreamAddressProvider().requestedServerName().empty()) { + result = std::string(stream_info.downstreamAddressProvider().requestedServerName()); } return result; }); diff --git a/source/common/http/conn_manager_impl.cc b/source/common/http/conn_manager_impl.cc index 12c29a1e0cbda..e792b2735d269 100644 --- a/source/common/http/conn_manager_impl.cc +++ b/source/common/http/conn_manager_impl.cc @@ -687,9 +687,6 @@ ConnectionManagerImpl::ActiveStream::ActiveStream(ConnectionManagerImpl& connect max_stream_duration_timer_->enableTimer(connection_manager_.config_.maxStreamDuration().value(), this); } - - filter_manager_.streamInfo().setRequestedServerName( - connection_manager_.read_callbacks_->connection().requestedServerName()); } void ConnectionManagerImpl::ActiveStream::completeRequest() { diff --git a/source/common/http/filter_manager.h b/source/common/http/filter_manager.h index 1e826f1df5d1c..008baa50bcd81 100644 --- a/source/common/http/filter_manager.h +++ b/source/common/http/filter_manager.h @@ -618,7 +618,9 @@ class OverridableRemoteSocketAddressSetterStreamInfo : public StreamInfo::Stream const Network::Address::InstanceConstSharedPtr& directRemoteAddress() const override { return StreamInfoImpl::downstreamAddressProvider().directRemoteAddress(); } - + absl::string_view requestedServerName() const override { + return StreamInfoImpl::downstreamAddressProvider().requestedServerName(); + } void dumpState(std::ostream& os, int indent_level) const override { StreamInfoImpl::dumpState(os, indent_level); diff --git a/source/common/network/listen_socket_impl.h b/source/common/network/listen_socket_impl.h index d42e8cadd683a..65e09b2b461ae 100644 --- a/source/common/network/listen_socket_impl.h +++ b/source/common/network/listen_socket_impl.h @@ -152,9 +152,11 @@ class ConnectionSocketImpl : public SocketImpl, public ConnectionSocket { void setRequestedServerName(absl::string_view server_name) override { // Always keep the server_name_ as lower case. - server_name_ = absl::AsciiStrToLower(server_name); + addressProvider().setRequestedServerName(absl::AsciiStrToLower(server_name)); + } + absl::string_view requestedServerName() const override { + return addressProvider().requestedServerName(); } - absl::string_view requestedServerName() const override { return server_name_; } absl::optional lastRoundTripTime() override { return ioHandle().lastRoundTripTime(); @@ -162,15 +164,13 @@ class ConnectionSocketImpl : public SocketImpl, public ConnectionSocket { void dumpState(std::ostream& os, int indent_level) const override { const char* spaces = spacesForLevel(indent_level); - os << spaces << "ListenSocketImpl " << this << DUMP_MEMBER(transport_protocol_) - << DUMP_MEMBER(server_name_) << "\n"; + os << spaces << "ListenSocketImpl " << this << DUMP_MEMBER(transport_protocol_) << "\n"; DUMP_DETAILS(address_provider_); } protected: std::string transport_protocol_; std::vector application_protocols_; - std::string server_name_; }; // ConnectionSocket used with server connections. diff --git a/source/common/network/socket_impl.h b/source/common/network/socket_impl.h index fd2765646f1ce..d277f536c3349 100644 --- a/source/common/network/socket_impl.h +++ b/source/common/network/socket_impl.h @@ -24,7 +24,8 @@ class SocketAddressSetterImpl : public SocketAddressSetter { os << spaces << "SocketAddressSetterImpl " << this << DUMP_NULLABLE_MEMBER(remote_address_, remote_address_->asStringView()) << DUMP_NULLABLE_MEMBER(direct_remote_address_, direct_remote_address_->asStringView()) - << DUMP_NULLABLE_MEMBER(local_address_, local_address_->asStringView()) << "\n"; + << DUMP_NULLABLE_MEMBER(local_address_, local_address_->asStringView()) + << DUMP_MEMBER(server_name_) << "\n"; } // SocketAddressSetter @@ -44,12 +45,17 @@ class SocketAddressSetterImpl : public SocketAddressSetter { const Address::InstanceConstSharedPtr& directRemoteAddress() const override { return direct_remote_address_; } + absl::string_view requestedServerName() const override { return server_name_; } + void setRequestedServerName(const absl::string_view requested_server_name) override { + server_name_ = std::string(requested_server_name); + } private: Address::InstanceConstSharedPtr local_address_; bool local_address_restored_{false}; Address::InstanceConstSharedPtr remote_address_; Address::InstanceConstSharedPtr direct_remote_address_; + std::string server_name_; }; class SocketImpl : public virtual Socket { diff --git a/source/common/stream_info/stream_info_impl.h b/source/common/stream_info/stream_info_impl.h index 633178d2c361b..1a41097e51420 100644 --- a/source/common/stream_info/stream_info_impl.h +++ b/source/common/stream_info/stream_info_impl.h @@ -233,12 +233,6 @@ struct StreamInfoImpl : public StreamInfo { upstream_filter_state_ = filter_state; } - void setRequestedServerName(absl::string_view requested_server_name) override { - requested_server_name_ = std::string(requested_server_name); - } - - const std::string& requestedServerName() const override { return requested_server_name_; } - void setUpstreamTransportFailureReason(absl::string_view failure_reason) override { upstream_transport_failure_reason_ = std::string(failure_reason); } diff --git a/source/common/tcp_proxy/tcp_proxy.cc b/source/common/tcp_proxy/tcp_proxy.cc index 37f71070dd9d2..ce3f18d7d65c2 100644 --- a/source/common/tcp_proxy/tcp_proxy.cc +++ b/source/common/tcp_proxy/tcp_proxy.cc @@ -655,9 +655,9 @@ void Filter::onUpstreamConnection() { read_callbacks_->upstreamHost()->outlierDetector().putResult( Upstream::Outlier::Result::LocalOriginConnectSuccessFinal); - getStreamInfo().setRequestedServerName(read_callbacks_->connection().requestedServerName()); ENVOY_CONN_LOG(debug, "TCP:onUpstreamEvent(), requestedServerName: {}", - read_callbacks_->connection(), getStreamInfo().requestedServerName()); + read_callbacks_->connection(), + getStreamInfo().downstreamAddressProvider().requestedServerName()); if (config_->idleTimeout()) { // The idle_timer_ can be moved to a Drainer, so related callbacks call into diff --git a/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc b/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc index 845412359782b..64e9923aaf450 100644 --- a/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc +++ b/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc @@ -171,7 +171,8 @@ void Utility::extractCommonAccessLogProperties( const Ssl::ConnectionInfoConstSharedPtr downstream_ssl_connection = stream_info.downstreamSslConnection(); - tls_properties->set_tls_sni_hostname(stream_info.requestedServerName()); + tls_properties->set_tls_sni_hostname( + std::string(stream_info.downstreamAddressProvider().requestedServerName())); auto* local_properties = tls_properties->mutable_local_certificate_properties(); for (const auto& uri_san : downstream_ssl_connection->uriSanLocalCertificate()) { diff --git a/source/extensions/filters/common/expr/context.cc b/source/extensions/filters/common/expr/context.cc index 66603460003c6..d073b3226494f 100644 --- a/source/extensions/filters/common/expr/context.cc +++ b/source/extensions/filters/common/expr/context.cc @@ -184,7 +184,7 @@ absl::optional ConnectionWrapper::operator[](CelValue key) const { return CelValue::CreateBool(info_.downstreamSslConnection() != nullptr && info_.downstreamSslConnection()->peerCertificatePresented()); } else if (value == RequestedServerName) { - return CelValue::CreateString(&info_.requestedServerName()); + return CelValue::CreateStringView(info_.downstreamAddressProvider().requestedServerName()); } else if (value == ID) { auto id = info_.connectionID(); if (id.has_value()) { diff --git a/source/extensions/filters/http/lua/wrappers.cc b/source/extensions/filters/http/lua/wrappers.cc index 04bd6a7b3eac5..bf70a111e2180 100644 --- a/source/extensions/filters/http/lua/wrappers.cc +++ b/source/extensions/filters/http/lua/wrappers.cc @@ -141,7 +141,7 @@ int StreamInfoWrapper::luaDownstreamDirectRemoteAddress(lua_State* state) { } int StreamInfoWrapper::luaRequestedServerName(lua_State* state) { - lua_pushstring(state, stream_info_.requestedServerName().c_str()); + lua_pushstring(state, stream_info_.downstreamAddressProvider().requestedServerName().data()); return 1; } diff --git a/test/common/formatter/substitution_formatter_test.cc b/test/common/formatter/substitution_formatter_test.cc index ed1d38fbbea91..72fdcafdfef38 100644 --- a/test/common/formatter/substitution_formatter_test.cc +++ b/test/common/formatter/substitution_formatter_test.cc @@ -674,8 +674,7 @@ TEST(SubstitutionFormatterTest, streamInfoFormatter) { { StreamInfoFormatter upstream_format("REQUESTED_SERVER_NAME"); std::string requested_server_name = "stub_server"; - EXPECT_CALL(stream_info, requestedServerName()) - .WillRepeatedly(ReturnRef(requested_server_name)); + stream_info.downstream_address_provider_->setRequestedServerName(requested_server_name); EXPECT_EQ("stub_server", upstream_format.format(request_headers, response_headers, response_trailers, stream_info, body)); EXPECT_THAT(upstream_format.formatValue(request_headers, response_headers, response_trailers, @@ -686,8 +685,7 @@ TEST(SubstitutionFormatterTest, streamInfoFormatter) { { StreamInfoFormatter upstream_format("REQUESTED_SERVER_NAME"); std::string requested_server_name; - EXPECT_CALL(stream_info, requestedServerName()) - .WillRepeatedly(ReturnRef(requested_server_name)); + stream_info.downstream_address_provider_->setRequestedServerName(requested_server_name); EXPECT_EQ(absl::nullopt, upstream_format.format(request_headers, response_headers, response_trailers, stream_info, body)); EXPECT_THAT(upstream_format.formatValue(request_headers, response_headers, response_trailers, diff --git a/test/common/http/conn_manager_impl_test.cc b/test/common/http/conn_manager_impl_test.cc index 08f56b0a338fb..164f7d46f947d 100644 --- a/test/common/http/conn_manager_impl_test.cc +++ b/test/common/http/conn_manager_impl_test.cc @@ -293,11 +293,7 @@ TEST_F(HttpConnectionManagerImplTest, 100ContinueResponseWithDecoderPause) { // When create new stream, the stream info will be populated from the connection. TEST_F(HttpConnectionManagerImplTest, PopulateStreamInfo) { setup(true, "", false); - - absl::string_view server_name = "fake-server"; EXPECT_CALL(filter_callbacks_.connection_, id()).WillRepeatedly(Return(1234)); - EXPECT_CALL(filter_callbacks_.connection_, requestedServerName()) - .WillRepeatedly(Return(server_name)); // Set up the codec. Buffer::OwnedImpl fake_input("input"); @@ -308,7 +304,7 @@ TEST_F(HttpConnectionManagerImplTest, PopulateStreamInfo) { EXPECT_EQ(requestIDExtension().get(), decoder_->streamInfo().getRequestIDProvider()); EXPECT_EQ(ssl_connection_, decoder_->streamInfo().downstreamSslConnection()); EXPECT_EQ(1234U, decoder_->streamInfo().connectionID()); - EXPECT_EQ(server_name, decoder_->streamInfo().requestedServerName()); + EXPECT_EQ(server_name_, decoder_->streamInfo().downstreamAddressProvider().requestedServerName()); // Clean up. filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); diff --git a/test/common/http/conn_manager_impl_test_base.cc b/test/common/http/conn_manager_impl_test_base.cc index 5cce6546afe8a..9f7b81bb97ae1 100644 --- a/test/common/http/conn_manager_impl_test_base.cc +++ b/test/common/http/conn_manager_impl_test_base.cc @@ -70,6 +70,8 @@ void HttpConnectionManagerImplTest::setup(bool ssl, const std::string& server_na std::make_shared("0.0.0.0")); filter_callbacks_.connection_.stream_info_.downstream_address_provider_ ->setDirectRemoteAddressForTest(std::make_shared("0.0.0.0")); + filter_callbacks_.connection_.stream_info_.downstream_address_provider_->setRequestedServerName( + server_name_); conn_manager_ = std::make_unique( *this, drain_close_, random_, http_context_, runtime_, local_info_, cluster_manager_, overload_manager_, test_time_.timeSystem()); diff --git a/test/common/network/connection_impl_test.cc b/test/common/network/connection_impl_test.cc index 63f5a10898aef..0e1dfa8d5e35d 100644 --- a/test/common/network/connection_impl_test.cc +++ b/test/common/network/connection_impl_test.cc @@ -1913,19 +1913,19 @@ TEST_P(ConnectionImplTest, NetworkSocketDumpsWithoutAllocatingMemory) { // Check socket dump const auto contents = ostream.contents(); EXPECT_THAT(contents, HasSubstr("ListenSocketImpl")); - EXPECT_THAT(contents, HasSubstr("transport_protocol_: , server_name_: envoyproxy.io")); + EXPECT_THAT(contents, HasSubstr("transport_protocol_: ")); EXPECT_THAT(contents, HasSubstr("SocketAddressSetterImpl")); if (GetParam() == Network::Address::IpVersion::v4) { EXPECT_THAT( contents, HasSubstr( "remote_address_: 1.1.1.1:80, direct_remote_address_: 1.1.1.1:80, local_address_: " - "1.2.3.4:56789")); + "1.2.3.4:56789, server_name_: envoyproxy.io")); } else { EXPECT_THAT( contents, HasSubstr("remote_address_: [::1]:80, direct_remote_address_: [::1]:80, local_address_: " - "[::1:2:3:4]:56789")); + "[::1:2:3:4]:56789, server_name_: envoyproxy.io")); } } diff --git a/test/common/stream_info/stream_info_impl_test.cc b/test/common/stream_info/stream_info_impl_test.cc index 88b2c04fb38c9..482490ca7874d 100644 --- a/test/common/stream_info/stream_info_impl_test.cc +++ b/test/common/stream_info/stream_info_impl_test.cc @@ -180,11 +180,6 @@ TEST_F(StreamInfoImplTest, MiscSettersAndGetters) { EXPECT_EQ(1, stream_info.upstreamFilterState()->getDataReadOnly("test").access()); - EXPECT_EQ("", stream_info.requestedServerName()); - absl::string_view sni_name = "stubserver.org"; - stream_info.setRequestedServerName(sni_name); - EXPECT_EQ(std::string(sni_name), stream_info.requestedServerName()); - EXPECT_EQ(absl::nullopt, stream_info.upstreamClusterInfo()); Upstream::ClusterInfoConstSharedPtr cluster_info(new NiceMock()); stream_info.setUpstreamClusterInfo(cluster_info); diff --git a/test/common/stream_info/test_util.h b/test/common/stream_info/test_util.h index 8335d6317b749..39b95e2b76549 100644 --- a/test/common/stream_info/test_util.h +++ b/test/common/stream_info/test_util.h @@ -178,12 +178,6 @@ class TestStreamInfo : public StreamInfo::StreamInfo { upstream_filter_state_ = filter_state; } - void setRequestedServerName(const absl::string_view requested_server_name) override { - requested_server_name_ = std::string(requested_server_name); - } - - const std::string& requestedServerName() const override { return requested_server_name_; } - void setUpstreamTransportFailureReason(absl::string_view failure_reason) override { upstream_transport_failure_reason_ = std::string(failure_reason); } diff --git a/test/extensions/access_loggers/grpc/http_grpc_access_log_impl_test.cc b/test/extensions/access_loggers/grpc/http_grpc_access_log_impl_test.cc index 42969c35dfa4e..54e012875513c 100644 --- a/test/extensions/access_loggers/grpc/http_grpc_access_log_impl_test.cc +++ b/test/extensions/access_loggers/grpc/http_grpc_access_log_impl_test.cc @@ -389,7 +389,7 @@ response: {} ON_CALL(*connection_info, tlsVersion()).WillByDefault(ReturnRef(tlsVersion)); ON_CALL(*connection_info, ciphersuiteId()).WillByDefault(Return(0x2CC0)); stream_info.setDownstreamSslConnection(connection_info); - stream_info.requested_server_name_ = "sni"; + stream_info.downstream_address_provider_->setRequestedServerName("sni"); Http::TestRequestHeaderMapImpl request_headers{ {":method", "WHACKADOO"}, @@ -449,7 +449,7 @@ response: {} ON_CALL(*connection_info, tlsVersion()).WillByDefault(ReturnRef(tlsVersion)); ON_CALL(*connection_info, ciphersuiteId()).WillByDefault(Return(0x2F)); stream_info.setDownstreamSslConnection(connection_info); - stream_info.requested_server_name_ = "sni"; + stream_info.downstream_address_provider_->setRequestedServerName("sni"); Http::TestRequestHeaderMapImpl request_headers{ {":method", "WHACKADOO"}, @@ -499,7 +499,7 @@ response: {} ON_CALL(*connection_info, tlsVersion()).WillByDefault(ReturnRef(tlsVersion)); ON_CALL(*connection_info, ciphersuiteId()).WillByDefault(Return(0x2F)); stream_info.setDownstreamSslConnection(connection_info); - stream_info.requested_server_name_ = "sni"; + stream_info.downstream_address_provider_->setRequestedServerName("sni"); Http::TestRequestHeaderMapImpl request_headers{ {":method", "WHACKADOO"}, @@ -549,7 +549,7 @@ response: {} ON_CALL(*connection_info, tlsVersion()).WillByDefault(ReturnRef(tlsVersion)); ON_CALL(*connection_info, ciphersuiteId()).WillByDefault(Return(0x2F)); stream_info.setDownstreamSslConnection(connection_info); - stream_info.requested_server_name_ = "sni"; + stream_info.downstream_address_provider_->setRequestedServerName("sni"); Http::TestRequestHeaderMapImpl request_headers{ {":method", "WHACKADOO"}, @@ -599,7 +599,7 @@ response: {} ON_CALL(*connection_info, tlsVersion()).WillByDefault(ReturnRef(tlsVersion)); ON_CALL(*connection_info, ciphersuiteId()).WillByDefault(Return(0x2F)); stream_info.setDownstreamSslConnection(connection_info); - stream_info.requested_server_name_ = "sni"; + stream_info.downstream_address_provider_->setRequestedServerName("sni"); Http::TestRequestHeaderMapImpl request_headers{ {":method", "WHACKADOO"}, diff --git a/test/extensions/filters/common/expr/context_test.cc b/test/extensions/filters/common/expr/context_test.cc index dc3e58b62bbb1..79475ef353c70 100644 --- a/test/extensions/filters/common/expr/context_test.cc +++ b/test/extensions/filters/common/expr/context_test.cc @@ -440,10 +440,10 @@ TEST(Context, ConnectionAttributes) { const std::string sni_name = "kittens.com"; info.downstream_address_provider_->setLocalAddress(local); info.downstream_address_provider_->setRemoteAddress(remote); + info.downstream_address_provider_->setRequestedServerName(sni_name); EXPECT_CALL(info, downstreamSslConnection()).WillRepeatedly(Return(downstream_ssl_info)); EXPECT_CALL(info, upstreamSslConnection()).WillRepeatedly(Return(upstream_ssl_info)); EXPECT_CALL(info, upstreamHost()).WillRepeatedly(Return(upstream_host)); - EXPECT_CALL(info, requestedServerName()).WillRepeatedly(ReturnRef(sni_name)); EXPECT_CALL(info, upstreamLocalAddress()).WillRepeatedly(ReturnRef(upstream_local_address)); const std::string upstream_transport_failure_reason = "ConnectionTermination"; EXPECT_CALL(info, upstreamTransportFailureReason()) diff --git a/test/extensions/filters/http/lua/lua_filter_test.cc b/test/extensions/filters/http/lua/lua_filter_test.cc index d1529ac2061e1..82d1623ff02cd 100644 --- a/test/extensions/filters/http/lua/lua_filter_test.cc +++ b/test/extensions/filters/http/lua/lua_filter_test.cc @@ -1794,8 +1794,8 @@ TEST_F(LuaHttpFilterTest, GetRequestedServerName) { setup(SCRIPT); EXPECT_CALL(decoder_callbacks_, streamInfo()).WillOnce(ReturnRef(stream_info_)); - std::string server_name = "foo.example.com"; - EXPECT_CALL(stream_info_, requestedServerName()).WillOnce(ReturnRef(server_name)); + absl::string_view server_name = "foo.example.com"; + stream_info_.downstream_address_provider_->setRequestedServerName(server_name); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq("foo.example.com"))); diff --git a/test/extensions/filters/http/lua/wrappers_test.cc b/test/extensions/filters/http/lua/wrappers_test.cc index ca10e89bf79b9..7b7b940149062 100644 --- a/test/extensions/filters/http/lua/wrappers_test.cc +++ b/test/extensions/filters/http/lua/wrappers_test.cc @@ -310,7 +310,7 @@ TEST_F(LuaStreamInfoWrapperTest, ReturnRequestedServerName) { setup(SCRIPT); NiceMock stream_info; - stream_info.requested_server_name_ = "some.sni.io"; + stream_info.downstream_address_provider_->setRequestedServerName("some.sni.io"); Filters::Common::Lua::LuaDeathRef wrapper( StreamInfoWrapper::create(coroutine_->luaState(), stream_info), true); EXPECT_CALL(printer_, testPrint("some.sni.io")); diff --git a/test/extensions/filters/http/wasm/wasm_filter_test.cc b/test/extensions/filters/http/wasm/wasm_filter_test.cc index 98d8c1c2671d7..43d7ebe70beb2 100644 --- a/test/extensions/filters/http/wasm/wasm_filter_test.cc +++ b/test/extensions/filters/http/wasm/wasm_filter_test.cc @@ -1682,7 +1682,7 @@ TEST_P(WasmHttpFilterTest, Property) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter().decodeHeaders(request_headers, true)); StreamInfo::MockStreamInfo log_stream_info; request_stream_info_.route_name_ = "route12"; - request_stream_info_.requested_server_name_ = "w3.org"; + request_stream_info_.downstream_address_provider_->setRequestedServerName("w3.org"); NiceMock connection; EXPECT_CALL(connection, id()).WillRepeatedly(Return(4)); EXPECT_CALL(encoder_callbacks_, connection()).WillRepeatedly(Return(&connection)); diff --git a/test/fuzz/utility.h b/test/fuzz/utility.h index 002eab8571a74..b478b9eab1faf 100644 --- a/test/fuzz/utility.h +++ b/test/fuzz/utility.h @@ -152,7 +152,6 @@ inline std::unique_ptr fromStreamInfo(const test::fuzz::StreamIn if (stream_info.has_response_code()) { test_stream_info->response_code_ = stream_info.response_code().value(); } - test_stream_info->setRequestedServerName(stream_info.requested_server_name()); auto upstream_host = std::make_shared>(); auto upstream_metadata = std::make_shared( replaceInvalidStringValues(stream_info.upstream_metadata())); @@ -168,6 +167,8 @@ inline std::unique_ptr fromStreamInfo(const test::fuzz::StreamIn test_stream_info->upstream_local_address_ = upstream_local_address; test_stream_info->downstream_address_provider_ = std::make_shared(address, address); + test_stream_info->downstream_address_provider_->setRequestedServerName( + stream_info.requested_server_name()); auto connection_info = std::make_shared>(); ON_CALL(*connection_info, subjectPeerCertificate()) .WillByDefault(testing::ReturnRef(TestSubjectPeer)); diff --git a/test/mocks/stream_info/mocks.cc b/test/mocks/stream_info/mocks.cc index 3907bb515f131..a2b7c768162d3 100644 --- a/test/mocks/stream_info/mocks.cc +++ b/test/mocks/stream_info/mocks.cc @@ -112,11 +112,6 @@ MockStreamInfo::MockStreamInfo() .WillByDefault(Invoke([this](const FilterStateSharedPtr& filter_state) { upstream_filter_state_ = filter_state; })); - ON_CALL(*this, setRequestedServerName(_)) - .WillByDefault(Invoke([this](const absl::string_view requested_server_name) { - requested_server_name_ = std::string(requested_server_name); - })); - ON_CALL(*this, requestedServerName()).WillByDefault(ReturnRef(requested_server_name_)); ON_CALL(*this, setRouteName(_)).WillByDefault(Invoke([this](const absl::string_view route_name) { route_name_ = std::string(route_name); })); diff --git a/test/mocks/stream_info/mocks.h b/test/mocks/stream_info/mocks.h index 1a0a81b36d9a5..dbc2fec6c5bd4 100644 --- a/test/mocks/stream_info/mocks.h +++ b/test/mocks/stream_info/mocks.h @@ -80,8 +80,6 @@ class MockStreamInfo : public StreamInfo { MOCK_METHOD(const FilterState&, filterState, (), (const)); MOCK_METHOD(const FilterStateSharedPtr&, upstreamFilterState, (), (const)); MOCK_METHOD(void, setUpstreamFilterState, (const FilterStateSharedPtr&)); - MOCK_METHOD(void, setRequestedServerName, (const absl::string_view)); - MOCK_METHOD(const std::string&, requestedServerName, (), (const)); MOCK_METHOD(void, setUpstreamTransportFailureReason, (absl::string_view)); MOCK_METHOD(const std::string&, upstreamTransportFailureReason, (), (const)); MOCK_METHOD(void, setRequestHeaders, (const Http::RequestHeaderMap&)); @@ -127,7 +125,6 @@ class MockStreamInfo : public StreamInfo { std::shared_ptr downstream_address_provider_; Ssl::ConnectionInfoConstSharedPtr downstream_connection_info_; Ssl::ConnectionInfoConstSharedPtr upstream_connection_info_; - std::string requested_server_name_; std::string route_name_; std::string upstream_transport_failure_reason_; std::string filter_chain_name_; diff --git a/test/server/BUILD b/test/server/BUILD index acfa8ec40fc7a..3b0005b9dad48 100644 --- a/test/server/BUILD +++ b/test/server/BUILD @@ -105,6 +105,7 @@ envoy_cc_test( "//source/server:connection_handler_lib", "//test/mocks/access_log:access_log_mocks", "//test/mocks/api:api_mocks", + "//test/mocks/network:io_handle_mocks", "//test/mocks/network:network_mocks", "//test/test_common:network_utility_lib", "//test/test_common:threadsafe_singleton_injector_lib", diff --git a/test/server/active_tcp_listener_test.cc b/test/server/active_tcp_listener_test.cc index 9d15d452475b6..baaf57421ec08 100644 --- a/test/server/active_tcp_listener_test.cc +++ b/test/server/active_tcp_listener_test.cc @@ -12,6 +12,7 @@ #include "test/mocks/api/mocks.h" #include "test/mocks/common.h" +#include "test/mocks/network/io_handle.h" #include "test/mocks/network/mocks.h" #include "test/test_common/network_utility.h" @@ -37,6 +38,7 @@ class MockTcpConnectionHandler : public Network::TcpConnectionHandler, MOCK_METHOD(Network::BalancedConnectionHandlerOptRef, getBalancedHandlerByAddress, (const Network::Address::Instance& address)); }; + class ActiveTcpListenerTest : public testing::Test, protected Logger::Loggable { public: ActiveTcpListenerTest() { @@ -60,6 +62,62 @@ class ActiveTcpListenerTest : public testing::Test, protected Logger::Loggable> listener_filter_matcher_; }; +TEST_F(ActiveTcpListenerTest, PopulateSNIWhenActiveTcpSocketTimeout) { + NiceMock balancer; + EXPECT_CALL(listener_config_, connectionBalancer()).WillRepeatedly(ReturnRef(balancer)); + EXPECT_CALL(listener_config_, listenerScope).Times(testing::AnyNumber()); + EXPECT_CALL(listener_config_, listenerFiltersTimeout()) + .WillOnce(Return(std::chrono::milliseconds(1000))); + EXPECT_CALL(listener_config_, continueOnListenerFiltersTimeout()); + EXPECT_CALL(listener_config_, openConnections()).WillRepeatedly(ReturnRef(resource_limit_)); + + auto listener = std::make_unique>(); + EXPECT_CALL(*listener, onDestroy()); + + auto* test_filter = new NiceMock(); + EXPECT_CALL(*test_filter, destroy_()); + EXPECT_CALL(listener_config_, filterChainFactory()) + .WillRepeatedly(ReturnRef(filter_chain_factory_)); + + // add a filter to stop the filter iteration. + EXPECT_CALL(filter_chain_factory_, createListenerFilterChain(_)) + .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{test_filter}); + return true; + })); + EXPECT_CALL(*test_filter, onAccept(_)) + .WillOnce(Invoke([](Network::ListenerFilterCallbacks&) -> Network::FilterStatus { + return Network::FilterStatus::StopIteration; + })); + + auto active_listener = + std::make_unique(conn_handler_, std::move(listener), listener_config_); + + absl::string_view server_name = "envoy.io"; + auto accepted_socket = std::make_unique>(); + accepted_socket->address_provider_->setRequestedServerName(server_name); + + // fake the socket is open. + NiceMock io_handle; + EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)); + EXPECT_CALL(io_handle, isOpen()).WillOnce(Return(true)); + + EXPECT_CALL(balancer, pickTargetHandler(_)) + .WillOnce(testing::DoAll( + testing::WithArg<0>(Invoke([](auto& target) { target.incNumConnections(); })), + ReturnRef(*active_listener))); + + // calling the onAcceptWorker() to create the ActiveTcpSocket. + active_listener->onAcceptWorker(std::move(accepted_socket), false, false); + // get the ActiveTcpSocket pointer before unlink() removed from the link-list. + ActiveTcpSocket* tcp_socket = active_listener->sockets_.front().get(); + // trigger the onTimeout event manually, since the timer is fake. + active_listener->sockets_.front()->onTimeout(); + + EXPECT_EQ(server_name, + tcp_socket->stream_info_->downstreamAddressProvider().requestedServerName()); +} + // Verify that the server connection with recovered address is rebalanced at redirected listener. TEST_F(ActiveTcpListenerTest, RedirectedRebalancer) { NiceMock listener_config1;