diff --git a/source/common/network/utility.cc b/source/common/network/utility.cc index 5b7ef1730c4a8..886eb027c7783 100644 --- a/source/common/network/utility.cc +++ b/source/common/network/utility.cc @@ -346,14 +346,12 @@ Address::InstanceConstSharedPtr Utility::getIpv6LoopbackAddress() { new Address::Ipv6Instance("::1", 0, nullptr)); } -Address::InstanceConstSharedPtr Utility::getIpv4AnyAddress() { - CONSTRUCT_ON_FIRST_USE(Address::InstanceConstSharedPtr, - new Address::Ipv4Instance(static_cast(0))); +Address::InstanceConstSharedPtr Utility::getIpv4AnyAddress(uint32_t port) { + CONSTRUCT_ON_FIRST_USE(Address::InstanceConstSharedPtr, new Address::Ipv4Instance(port)); } -Address::InstanceConstSharedPtr Utility::getIpv6AnyAddress() { - CONSTRUCT_ON_FIRST_USE(Address::InstanceConstSharedPtr, - new Address::Ipv6Instance(static_cast(0))); +Address::InstanceConstSharedPtr Utility::getIpv6AnyAddress(uint32_t port) { + CONSTRUCT_ON_FIRST_USE(Address::InstanceConstSharedPtr, new Address::Ipv6Instance(port)); } const std::string& Utility::getIpv4CidrCatchAllAddress() { diff --git a/source/common/network/utility.h b/source/common/network/utility.h index d042d130071f6..721140b530def 100644 --- a/source/common/network/utility.h +++ b/source/common/network/utility.h @@ -246,18 +246,20 @@ class Utility { static Address::InstanceConstSharedPtr getIpv6LoopbackAddress(); /** + * @param port to be included in address, the default is 0. * @return Address::InstanceConstSharedPtr an address that represents the IPv4 wildcard address * (i.e. "0.0.0.0"). Used during binding to indicate that incoming connections to any * local IPv4 address are to be accepted. */ - static Address::InstanceConstSharedPtr getIpv4AnyAddress(); + static Address::InstanceConstSharedPtr getIpv4AnyAddress(uint32_t port = 0); /** + * @param port to be included in address, the default is 0. * @return Address::InstanceConstSharedPtr an address that represents the IPv6 wildcard address * (i.e. "::"). Used during binding to indicate that incoming connections to any local * IPv6 address are to be accepted. */ - static Address::InstanceConstSharedPtr getIpv6AnyAddress(); + static Address::InstanceConstSharedPtr getIpv6AnyAddress(uint32_t port = 0); /** * @return the IPv4 CIDR catch-all address (0.0.0.0/0). diff --git a/source/server/connection_handler_impl.cc b/source/server/connection_handler_impl.cc index 24789e5feb5aa..89533d944818c 100644 --- a/source/server/connection_handler_impl.cc +++ b/source/server/connection_handler_impl.cc @@ -39,60 +39,84 @@ void ConnectionHandlerImpl::addListener(absl::optional overridden_list return; } - ActiveListenerDetails details; + auto details = std::make_shared(); if (config.internalListenerConfig().has_value()) { if (overridden_listener.has_value()) { - for (auto& listener : listeners_) { - if (listener.second.listener_->listenerTag() == overridden_listener) { - listener.second.internalListener()->get().updateListenerConfig(config); - return; - } + if (auto iter = listener_map_by_tag_.find(overridden_listener.value()); + iter != listener_map_by_tag_.end()) { + iter->second->internalListener()->get().updateListenerConfig(config); + return; } NOT_REACHED_GCOVR_EXCL_LINE; } auto internal_listener = std::make_unique(*this, dispatcher(), config); - details.typed_listener_ = *internal_listener; - details.listener_ = std::move(internal_listener); + details->typed_listener_ = *internal_listener; + details->listener_ = std::move(internal_listener); } else if (config.listenSocketFactory().socketType() == Network::Socket::Type::Stream) { if (!support_udp_in_place_filter_chain_update && overridden_listener.has_value()) { - for (auto& listener : listeners_) { - if (listener.second.listener_->listenerTag() == overridden_listener) { - listener.second.tcpListener()->get().updateListenerConfig(config); - return; - } + if (auto iter = listener_map_by_tag_.find(overridden_listener.value()); + iter != listener_map_by_tag_.end()) { + iter->second->tcpListener()->get().updateListenerConfig(config); + return; } NOT_REACHED_GCOVR_EXCL_LINE; } // worker_index_ doesn't have a value on the main thread for the admin server. auto tcp_listener = std::make_unique( *this, config, worker_index_.has_value() ? *worker_index_ : 0); - details.typed_listener_ = *tcp_listener; - details.listener_ = std::move(tcp_listener); + details->typed_listener_ = *tcp_listener; + details->listener_ = std::move(tcp_listener); } else { ASSERT(config.udpListenerConfig().has_value(), "UDP listener factory is not initialized."); ASSERT(worker_index_.has_value()); ConnectionHandler::ActiveUdpListenerPtr udp_listener = config.udpListenerConfig()->listenerFactory().createActiveUdpListener(*worker_index_, *this, dispatcher_, config); - details.typed_listener_ = *udp_listener; - details.listener_ = std::move(udp_listener); + details->typed_listener_ = *udp_listener; + details->listener_ = std::move(udp_listener); } + if (disable_listeners_) { - details.listener_->pauseListening(); + details->listener_->pauseListening(); } - if (auto* listener = details.listener_->listener(); listener != nullptr) { + if (auto* listener = details->listener_->listener(); listener != nullptr) { listener->setRejectFraction(listener_reject_fraction_); } - listeners_.emplace_back(config.listenSocketFactory().localAddress(), std::move(details)); + + details->listener_tag_ = config.listenerTag(); + details->address_ = config.listenSocketFactory().localAddress(); + + ASSERT(!listener_map_by_tag_.contains(config.listenerTag())); + + listener_map_by_tag_.emplace(config.listenerTag(), details); + // This map only store the new listener. + if (absl::holds_alternative>( + details->typed_listener_)) { + tcp_listener_map_by_address_.insert_or_assign( + config.listenSocketFactory().localAddress()->asStringView(), details); + } else if (absl::holds_alternative>( + details->typed_listener_)) { + internal_listener_map_by_address_.insert_or_assign( + config.listenSocketFactory().localAddress()->asStringView(), details); + } } void ConnectionHandlerImpl::removeListeners(uint64_t listener_tag) { - for (auto listener = listeners_.begin(); listener != listeners_.end();) { - if (listener->second.listener_->listenerTag() == listener_tag) { - listener = listeners_.erase(listener); - } else { - ++listener; + if (auto listener_iter = listener_map_by_tag_.find(listener_tag); + listener_iter != listener_map_by_tag_.end()) { + // listener_map_by_address_ may already update to the new listener. Compare it with the one + // which find from listener_map_by_tag_, only delete it when it is same listener. + auto address_view = listener_iter->second->address_->asStringView(); + if (tcp_listener_map_by_address_.contains(address_view) && + tcp_listener_map_by_address_[address_view]->listener_tag_ == + listener_iter->second->listener_tag_) { + tcp_listener_map_by_address_.erase(address_view); + } else if (internal_listener_map_by_address_.contains(address_view) && + internal_listener_map_by_address_[address_view]->listener_tag_ == + listener_iter->second->listener_tag_) { + internal_listener_map_by_address_.erase(address_view); } + listener_map_by_tag_.erase(listener_iter); } } @@ -112,12 +136,11 @@ ConnectionHandlerImpl::getUdpListenerCallbacks(uint64_t listener_tag) { void ConnectionHandlerImpl::removeFilterChains( uint64_t listener_tag, const std::list& filter_chains, std::function completion) { - for (auto& listener : listeners_) { - if (listener.second.listener_->listenerTag() == listener_tag) { - listener.second.listener_->onFilterChainDraining(filter_chains); - break; - } + if (auto listener_it = listener_map_by_tag_.find(listener_tag); + listener_it != listener_map_by_tag_.end()) { + listener_it->second->listener_->onFilterChainDraining(filter_chains); } + // Reach here if the target listener is found or the target listener was removed by a full // listener update. In either case, the completion must be deferred so that any active connection // referencing the filter chain can finish prior to deletion. @@ -125,55 +148,54 @@ void ConnectionHandlerImpl::removeFilterChains( } void ConnectionHandlerImpl::stopListeners(uint64_t listener_tag) { - for (auto& listener : listeners_) { - if (listener.second.listener_->listenerTag() == listener_tag) { - listener.second.listener_->shutdownListener(); + if (auto iter = listener_map_by_tag_.find(listener_tag); iter != listener_map_by_tag_.end()) { + if (iter->second->listener_->listener() != nullptr) { + iter->second->listener_->shutdownListener(); } } } void ConnectionHandlerImpl::stopListeners() { - for (auto& listener : listeners_) { - listener.second.listener_->shutdownListener(); + for (auto& iter : listener_map_by_tag_) { + if (iter.second->listener_->listener() != nullptr) { + iter.second->listener_->shutdownListener(); + } } } void ConnectionHandlerImpl::disableListeners() { disable_listeners_ = true; - for (auto& listener : listeners_) { - listener.second.listener_->pauseListening(); + for (auto& iter : listener_map_by_tag_) { + if (iter.second->listener_->listener() != nullptr) { + iter.second->listener_->pauseListening(); + } } } void ConnectionHandlerImpl::enableListeners() { disable_listeners_ = false; - for (auto& listener : listeners_) { - listener.second.listener_->resumeListening(); + for (auto& iter : listener_map_by_tag_) { + if (iter.second->listener_->listener() != nullptr) { + iter.second->listener_->resumeListening(); + } } } void ConnectionHandlerImpl::setListenerRejectFraction(UnitFloat reject_fraction) { listener_reject_fraction_ = reject_fraction; - for (auto& listener : listeners_) { - listener.second.listener_->listener()->setRejectFraction(reject_fraction); + for (auto& iter : listener_map_by_tag_) { + if (iter.second->listener_->listener() != nullptr) { + iter.second->listener_->listener()->setRejectFraction(reject_fraction); + } } } Network::InternalListenerOptRef ConnectionHandlerImpl::findByAddress(const Network::Address::InstanceConstSharedPtr& address) { ASSERT(address->type() == Network::Address::Type::EnvoyInternal); - auto listener_it = - std::find_if(listeners_.begin(), listeners_.end(), - [&address](std::pair& p) { - return p.second.internalListener().has_value() && - p.second.listener_->listener() != nullptr && - p.first->type() == Network::Address::Type::EnvoyInternal && - *(p.first) == *address; - }); - - if (listener_it != listeners_.end()) { - return Network::InternalListenerOptRef(listener_it->second.internalListener().value().get()); + if (auto listener_it = internal_listener_map_by_address_.find(address->asStringView()); + listener_it != internal_listener_map_by_address_.end()) { + return Network::InternalListenerOptRef(listener_it->second->internalListener().value().get()); } return OptRef(); } @@ -198,15 +220,9 @@ ConnectionHandlerImpl::ActiveListenerDetails::internalListener() { ConnectionHandlerImpl::ActiveListenerDetailsOptRef ConnectionHandlerImpl::findActiveListenerByTag(uint64_t listener_tag) { - // TODO(mattklein123): We should probably use a hash table here to lookup the tag - // instead of iterating through the listener list. - for (auto& listener : listeners_) { - if (listener.second.listener_->listener() != nullptr && - listener.second.listener_->listenerTag() == listener_tag) { - return listener.second; - } + if (auto iter = listener_map_by_tag_.find(listener_tag); iter != listener_map_by_tag_.end()) { + return *iter->second; } - return absl::nullopt; } @@ -224,58 +240,43 @@ ConnectionHandlerImpl::getBalancedHandlerByTag(uint64_t listener_tag) { Network::BalancedConnectionHandlerOptRef ConnectionHandlerImpl::getBalancedHandlerByAddress(const Network::Address::Instance& address) { - // This is a linear operation, may need to add a map to improve performance. - // However, linear performance might be adequate since the number of listeners is small. // We do not return stopped listeners. - auto listener_it = - std::find_if(listeners_.begin(), listeners_.end(), - [&address](std::pair& p) { - return p.second.tcpListener().has_value() && - p.second.listener_->listener() != nullptr && - p.first->type() == Network::Address::Type::Ip && *(p.first) == address; - }); - // If there is exact address match, return the corresponding listener. - if (listener_it != listeners_.end()) { + if (auto listener_it = tcp_listener_map_by_address_.find(address.asStringView()); + listener_it != tcp_listener_map_by_address_.end()) { return Network::BalancedConnectionHandlerOptRef( - listener_it->second.tcpListener().value().get()); + listener_it->second->tcpListener().value().get()); } + OptRef details; // Otherwise, we need to look for the wild card match, i.e., 0.0.0.0:[address_port]. // We do not return stopped listeners. // TODO(wattli): consolidate with previous search for more efficiency. if (Runtime::runtimeFeatureEnabled( "envoy.reloadable_features.listener_wildcard_match_ip_family")) { - listener_it = - std::find_if(listeners_.begin(), listeners_.end(), - [&address](const std::pair& p) { - return absl::holds_alternative>( - p.second.typed_listener_) && - p.second.listener_->listener() != nullptr && - p.first->type() == Network::Address::Type::Ip && - p.first->ip()->port() == address.ip()->port() && - p.first->ip()->isAnyAddress() && - p.first->ip()->version() == address.ip()->version(); - }); + std::string addr_str = + address.ip()->version() == Network::Address::IpVersion::v4 + ? Network::Utility::getIpv4AnyAddress(address.ip()->port())->asString() + : Network::Utility::getIpv6AnyAddress(address.ip()->port())->asString(); + + auto iter = tcp_listener_map_by_address_.find(addr_str); + if (iter != tcp_listener_map_by_address_.end()) { + details = *iter->second; + } } else { - listener_it = - std::find_if(listeners_.begin(), listeners_.end(), - [&address](const std::pair& p) { - return absl::holds_alternative>( - p.second.typed_listener_) && - p.second.listener_->listener() != nullptr && - p.first->type() == Network::Address::Type::Ip && - p.first->ip()->port() == address.ip()->port() && - p.first->ip()->isAnyAddress(); - }); + for (auto& iter : tcp_listener_map_by_address_) { + if (iter.second->listener_->listener() != nullptr && + iter.second->address_->type() == Network::Address::Type::Ip && + iter.second->address_->ip()->port() == address.ip()->port() && + iter.second->address_->ip()->isAnyAddress()) { + details = *iter.second; + } + } } - return (listener_it != listeners_.end()) + return (details.has_value()) ? Network::BalancedConnectionHandlerOptRef( ActiveTcpListenerOptRef(absl::get>( - listener_it->second.typed_listener_)) + details->typed_listener_)) .value() .get()) : absl::nullopt; diff --git a/source/server/connection_handler_impl.h b/source/server/connection_handler_impl.h index 23a39659c55d1..e700bf04b9115 100644 --- a/source/server/connection_handler_impl.h +++ b/source/server/connection_handler_impl.h @@ -75,6 +75,8 @@ class ConnectionHandlerImpl : public Network::TcpConnectionHandler, struct ActiveListenerDetails { // Strong pointer to the listener, whether TCP, UDP, QUIC, etc. Network::ConnectionHandler::ActiveListenerPtr listener_; + Network::Address::InstanceConstSharedPtr address_; + uint64_t listener_tag_; absl::variant, std::reference_wrapper, @@ -93,7 +95,12 @@ class ConnectionHandlerImpl : public Network::TcpConnectionHandler, const absl::optional worker_index_; Event::Dispatcher& dispatcher_; const std::string per_handler_stat_prefix_; - std::list> listeners_; + absl::flat_hash_map> listener_map_by_tag_; + absl::flat_hash_map> + tcp_listener_map_by_address_; + absl::flat_hash_map> + internal_listener_map_by_address_; + std::atomic num_handler_connections_{}; bool disable_listeners_; UnitFloat listener_reject_fraction_{UnitFloat::min()}; diff --git a/test/extensions/common/proxy_protocol/proxy_protocol_regression_test.cc b/test/extensions/common/proxy_protocol/proxy_protocol_regression_test.cc index 0d906a0b79287..7f961c4783b89 100644 --- a/test/extensions/common/proxy_protocol/proxy_protocol_regression_test.cc +++ b/test/extensions/common/proxy_protocol/proxy_protocol_regression_test.cc @@ -48,7 +48,7 @@ class ProxyProtocolRegressionTest : public testing::TestWithParamconnectionInfoProvider().localAddress())); + .WillRepeatedly(ReturnRef(socket_->connectionInfoProvider().localAddress())); EXPECT_CALL(socket_factory_, getListenSocket(_)).WillOnce(Return(socket_)); connection_handler_->addListener(absl::nullopt, *this); conn_ = dispatcher_->createClientConnection(socket_->connectionInfoProvider().localAddress(), diff --git a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc index 941c474ece02a..7f99111dc1cbc 100644 --- a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc +++ b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc @@ -62,7 +62,7 @@ class ProxyProtocolTest : public testing::TestWithParamconnectionInfoProvider().localAddress())); + .WillRepeatedly(ReturnRef(socket_->connectionInfoProvider().localAddress())); EXPECT_CALL(socket_factory_, getListenSocket(_)).WillOnce(Return(socket_)); connection_handler_->addListener(absl::nullopt, *this); conn_ = dispatcher_->createClientConnection(socket_->connectionInfoProvider().localAddress(), @@ -1369,7 +1369,7 @@ class WildcardProxyProtocolTest : public testing::TestWithParamconnectionInfoProvider().localAddress())); + .WillRepeatedly(ReturnRef(socket_->connectionInfoProvider().localAddress())); EXPECT_CALL(socket_factory_, getListenSocket(_)).WillOnce(Return(socket_)); connection_handler_->addListener(absl::nullopt, *this); conn_ = dispatcher_->createClientConnection(local_dst_address_, diff --git a/test/server/connection_handler_test.cc b/test/server/connection_handler_test.cc index d2b733072283b..2ad3a517cd469 100644 --- a/test/server/connection_handler_test.cc +++ b/test/server/connection_handler_test.cc @@ -361,7 +361,8 @@ TEST_F(ConnectionHandlerTest, RemoveListenerDuringRebalance) { TestListener* test_listener = addListener(1, true, false, "test_listener", listener, &listener_callbacks, connection_balancer, ¤t_handler); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); // Fake a balancer posting a connection to us. @@ -486,7 +487,8 @@ TEST_F(ConnectionHandlerTest, RemoveListener) { auto listener = new NiceMock(); TestListener* test_listener = addListener(1, true, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); Network::MockConnectionSocket* connection = new NiceMock(); @@ -500,7 +502,6 @@ TEST_F(ConnectionHandlerTest, RemoveListener) { EXPECT_CALL(*listener, onDestroy()); handler_->stopListeners(1); - EXPECT_CALL(dispatcher_, clearDeferredDeleteList()); handler_->removeListeners(1); EXPECT_EQ(0UL, handler_->numConnections()); @@ -517,7 +518,8 @@ TEST_F(ConnectionHandlerTest, DisableListener) { auto listener = new NiceMock(); TestListener* test_listener = addListener(1, false, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); EXPECT_CALL(*listener, disable()); @@ -526,6 +528,28 @@ TEST_F(ConnectionHandlerTest, DisableListener) { handler_->disableListeners(); } +// Envoy doesn't have such case yet, just ensure the code won't break with it. +TEST_F(ConnectionHandlerTest, StopAndDisableStoppedListener) { + InSequence s; + + Network::TcpListenerCallbacks* listener_callbacks; + auto listener = new NiceMock(); + TestListener* test_listener = + addListener(1, false, false, "test_listener", listener, &listener_callbacks); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); + handler_->addListener(absl::nullopt, *test_listener); + + EXPECT_CALL(*listener, onDestroy()); + handler_->stopListeners(1); + + // Test stop a stopped listener. + handler_->stopListeners(1); + + // Test disable a stopped listener. + handler_->disableListeners(); +} + TEST_F(ConnectionHandlerTest, AddDisabledListener) { InSequence s; @@ -534,7 +558,8 @@ TEST_F(ConnectionHandlerTest, AddDisabledListener) { TestListener* test_listener = addListener(1, false, false, "test_listener", listener, &listener_callbacks); EXPECT_CALL(*listener, disable()); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); EXPECT_CALL(*listener, onDestroy()); handler_->disableListeners(); @@ -548,7 +573,8 @@ TEST_F(ConnectionHandlerTest, SetListenerRejectFraction) { auto listener = new NiceMock(); TestListener* test_listener = addListener(1, false, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); EXPECT_CALL(*listener, setRejectFraction(UnitFloat(0.1234f))); @@ -565,7 +591,8 @@ TEST_F(ConnectionHandlerTest, AddListenerSetRejectFraction) { TestListener* test_listener = addListener(1, false, false, "test_listener", listener, &listener_callbacks); EXPECT_CALL(*listener, setRejectFraction(UnitFloat(0.12345f))); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); EXPECT_CALL(*listener, onDestroy()); handler_->setListenerRejectFraction(UnitFloat(0.12345f)); @@ -580,7 +607,8 @@ TEST_F(ConnectionHandlerTest, SetsTransportSocketConnectTimeout) { TestListener* test_listener = addListener(1, false, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); auto server_connection = new NiceMock(); @@ -605,7 +633,8 @@ TEST_F(ConnectionHandlerTest, DestroyCloseConnections) { auto listener = new NiceMock(); TestListener* test_listener = addListener(1, true, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); Network::MockConnectionSocket* connection = new NiceMock(); @@ -625,7 +654,8 @@ TEST_F(ConnectionHandlerTest, CloseDuringFilterChainCreate) { auto listener = new NiceMock(); TestListener* test_listener = addListener(1, true, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); EXPECT_CALL(manager_, findFilterChain(_)).WillOnce(Return(filter_chain_.get())); @@ -649,7 +679,8 @@ TEST_F(ConnectionHandlerTest, CloseConnectionOnEmptyFilterChain) { auto listener = new NiceMock(); TestListener* test_listener = addListener(1, true, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); EXPECT_CALL(manager_, findFilterChain(_)).WillOnce(Return(filter_chain_.get())); @@ -680,7 +711,7 @@ TEST_F(ConnectionHandlerTest, NormalRedirect) { Network::TcpListenerCallbacks* listener_callbacks2; auto listener2 = new NiceMock(); TestListener* test_listener2 = - addListener(1, false, false, "test_listener2", listener2, &listener_callbacks2); + addListener(2, false, false, "test_listener2", listener2, &listener_callbacks2); Network::Address::InstanceConstSharedPtr alt_address( new Network::Address::Ipv4Instance("127.0.0.2", 20002)); EXPECT_CALL(test_listener2->socket_factory_, localAddress()) @@ -733,6 +764,78 @@ TEST_F(ConnectionHandlerTest, NormalRedirect) { EXPECT_CALL(*listener1, onDestroy()); } +// When update a listener, the old listener will be stopped and the new listener will +// be added into ConnectionHandler before remove the old listener from ConnectionHandler. +// This test ensure ConnectionHandler can query the correct Listener when balanced the connection +// through `getBalancedHandlerByAddress` +TEST_F(ConnectionHandlerTest, MatchLatestListener) { + Network::TcpListenerCallbacks* listener_callbacks; + // The Listener1 will accept the new connection first then balanced to other listener. + auto listener1 = new NiceMock(); + TestListener* test_listener1 = + addListener(1, true, true, "test_listener1", listener1, &listener_callbacks); + EXPECT_CALL(test_listener1->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); + handler_->addListener(absl::nullopt, *test_listener1); + + // Listener2 will be replaced by Listener3. + auto listener2 = new NiceMock(); + TestListener* test_listener2 = addListener(2, false, false, "test_listener2", listener2); + Network::Address::InstanceConstSharedPtr listener2_address( + new Network::Address::Ipv4Instance("127.0.0.1", 10002)); + EXPECT_CALL(test_listener2->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(listener2_address)); + handler_->addListener(absl::nullopt, *test_listener2); + + // Listener3 will replace the listener2. + auto listener3 = new NiceMock(); + TestListener* test_listener3 = addListener(3, false, false, "test_listener3", listener3); + Network::Address::InstanceConstSharedPtr listener3_address( + new Network::Address::Ipv4Instance("127.0.0.1", 10002)); + EXPECT_CALL(test_listener3->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(listener3_address)); + + // This emulated the case of update listener in-place. Stop the old listener and + // add the new listener. + EXPECT_CALL(*listener2, onDestroy()); + handler_->stopListeners(2); + handler_->addListener(absl::nullopt, *test_listener3); + + Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(); + EXPECT_CALL(*test_filter, destroy_()); + Network::MockConnectionSocket* accepted_socket = new NiceMock(); + bool redirected = false; + EXPECT_CALL(factory_, createListenerFilterChain(_)) + .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { + // Insert the Mock filter. + if (!redirected) { + manager.addAcceptFilter(listener_filter_matcher_, + Network::ListenerFilterPtr{test_filter}); + redirected = true; + } + return true; + })); + // This is the address of listener2 and listener3. + Network::Address::InstanceConstSharedPtr alt_address( + new Network::Address::Ipv4Instance("127.0.0.1", 10002, nullptr)); + EXPECT_CALL(*test_filter, onAccept(_)) + .WillOnce(Invoke([&](Network::ListenerFilterCallbacks& cb) -> Network::FilterStatus { + cb.socket().connectionInfoProvider().restoreLocalAddress(alt_address); + return Network::FilterStatus::Continue; + })); + EXPECT_CALL(manager_, findFilterChain(_)).WillOnce(Return(filter_chain_.get())); + + auto* connection = new NiceMock(); + EXPECT_CALL(dispatcher_, createServerConnection_()).WillOnce(Return(connection)); + EXPECT_CALL(factory_, createNetworkFilterChain(_, _)).WillOnce(Return(true)); + listener_callbacks->onAccept(Network::ConnectionSocketPtr{accepted_socket}); + EXPECT_EQ(1UL, handler_->numConnections()); + + EXPECT_CALL(*listener3, onDestroy()); + EXPECT_CALL(*listener1, onDestroy()); + EXPECT_CALL(*access_log_, log(_, _, _, _)); +} + TEST_F(ConnectionHandlerTest, FallbackToWildcardListener) { Network::TcpListenerCallbacks* listener_callbacks1; auto listener1 = new NiceMock(); @@ -747,7 +850,7 @@ TEST_F(ConnectionHandlerTest, FallbackToWildcardListener) { Network::TcpListenerCallbacks* listener_callbacks2; auto listener2 = new NiceMock(); TestListener* test_listener2 = - addListener(1, false, false, "test_listener2", listener2, &listener_callbacks2); + addListener(2, false, false, "test_listener2", listener2, &listener_callbacks2); Network::Address::InstanceConstSharedPtr any_address = Network::Utility::getIpv4AnyAddress(); EXPECT_CALL(test_listener2->socket_factory_, localAddress()) .WillRepeatedly(ReturnRef(any_address)); @@ -787,7 +890,7 @@ TEST_F(ConnectionHandlerTest, FallbackToWildcardListener) { EXPECT_CALL(*access_log_, log(_, _, _, _)); } -TEST_F(ConnectionHandlerTest, OldBehaviorMatchFirstWildcardListener) { +TEST_F(ConnectionHandlerTest, OldBehaviorWildcardListener) { auto scoped_runtime = std::make_unique(); Runtime::LoaderSingleton::getExisting()->mergeValues( @@ -808,7 +911,7 @@ TEST_F(ConnectionHandlerTest, OldBehaviorMatchFirstWildcardListener) { Network::TcpListenerCallbacks* ipv4_any_listener_callbacks; auto listener2 = new NiceMock(); TestListener* ipv4_any_listener = - addListener(1, false, false, "ipv4_any_test_listener", listener2, + addListener(2, false, false, "ipv4_any_test_listener", listener2, &ipv4_any_listener_callbacks, nullptr, nullptr, Network::Socket::Type::Stream, std::chrono::milliseconds(15000), false, ipv4_overridden_filter_chain_manager); Network::Address::InstanceConstSharedPtr any_address( @@ -817,20 +920,6 @@ TEST_F(ConnectionHandlerTest, OldBehaviorMatchFirstWildcardListener) { .WillRepeatedly(ReturnRef(any_address)); handler_->addListener(absl::nullopt, *ipv4_any_listener); - auto ipv6_overridden_filter_chain_manager = - std::make_shared>(); - Network::TcpListenerCallbacks* ipv6_any_listener_callbacks; - auto listener3 = new NiceMock(); - TestListener* ipv6_any_listener = - addListener(1, false, false, "ipv6_any_test_listener", listener3, - &ipv6_any_listener_callbacks, nullptr, nullptr, Network::Socket::Type::Stream, - std::chrono::milliseconds(15000), false, ipv6_overridden_filter_chain_manager); - Network::Address::InstanceConstSharedPtr any_address_ipv6( - new Network::Address::Ipv6Instance("::", 80)); - EXPECT_CALL(ipv6_any_listener->socket_factory_, localAddress()) - .WillRepeatedly(ReturnRef(any_address_ipv6)); - handler_->addListener(absl::nullopt, *ipv6_any_listener); - Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(); EXPECT_CALL(*test_filter, destroy_()); Network::MockConnectionSocket* accepted_socket = new NiceMock(); @@ -856,14 +945,12 @@ TEST_F(ConnectionHandlerTest, OldBehaviorMatchFirstWildcardListener) { EXPECT_CALL(manager_, findFilterChain(_)).Times(0); EXPECT_CALL(*ipv4_overridden_filter_chain_manager, findFilterChain(_)) .WillOnce(Return(filter_chain_.get())); - EXPECT_CALL(*ipv6_overridden_filter_chain_manager, findFilterChain(_)).Times(0); auto* connection = new NiceMock(); EXPECT_CALL(dispatcher_, createServerConnection_()).WillOnce(Return(connection)); EXPECT_CALL(factory_, createNetworkFilterChain(_, _)).WillOnce(Return(true)); listener_callbacks1->onAccept(Network::ConnectionSocketPtr{accepted_socket}); EXPECT_EQ(1UL, handler_->numConnections()); - EXPECT_CALL(*listener3, onDestroy()); EXPECT_CALL(*listener2, onDestroy()); EXPECT_CALL(*listener1, onDestroy()); EXPECT_CALL(*access_log_, log(_, _, _, _)); @@ -887,7 +974,7 @@ TEST_F(ConnectionHandlerTest, MatchIPv6WildcardListener) { Network::TcpListenerCallbacks* ipv4_any_listener_callbacks; auto listener2 = new NiceMock(); TestListener* ipv4_any_listener = - addListener(1, false, false, "ipv4_any_test_listener", listener2, + addListener(2, false, false, "ipv4_any_test_listener", listener2, &ipv4_any_listener_callbacks, nullptr, nullptr, Network::Socket::Type::Stream, std::chrono::milliseconds(15000), false, ipv4_overridden_filter_chain_manager); @@ -902,7 +989,7 @@ TEST_F(ConnectionHandlerTest, MatchIPv6WildcardListener) { Network::TcpListenerCallbacks* ipv6_any_listener_callbacks; auto listener3 = new NiceMock(); TestListener* ipv6_any_listener = - addListener(1, false, false, "ipv6_any_test_listener", listener3, + addListener(3, false, false, "ipv6_any_test_listener", listener3, &ipv6_any_listener_callbacks, nullptr, nullptr, Network::Socket::Type::Stream, std::chrono::milliseconds(15000), false, ipv6_overridden_filter_chain_manager); Network::Address::InstanceConstSharedPtr any_address_ipv6( @@ -1315,7 +1402,7 @@ TEST_F(ConnectionHandlerTest, TcpListenerInplaceUpdate) { addListener(old_listener_tag, true, false, "test_listener", old_listener, &old_listener_callbacks, mock_connection_balancer, ¤t_handler); EXPECT_CALL(old_test_listener->socket_factory_, localAddress()) - .WillOnce(ReturnRef(local_address_)); + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *old_test_listener); ASSERT_NE(old_test_listener, nullptr); @@ -1352,7 +1439,8 @@ TEST_F(ConnectionHandlerTest, TcpListenerRemoveFilterChain) { auto listener = new NiceMock(); TestListener* test_listener = addListener(listener_tag, true, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); Network::MockConnectionSocket* connection = new NiceMock(); @@ -1400,7 +1488,8 @@ TEST_F(ConnectionHandlerTest, TcpListenerRemoveFilterChainCalledAfterListenerIsR auto listener = new NiceMock(); TestListener* test_listener = addListener(listener_tag, true, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); Network::MockConnectionSocket* connection = new NiceMock(); @@ -1462,7 +1551,8 @@ TEST_F(ConnectionHandlerTest, TcpListenerRemoveListener) { auto listener = new NiceMock(); TestListener* test_listener = addListener(1, true, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); Network::MockConnectionSocket* connection = new NiceMock(); @@ -1491,7 +1581,8 @@ TEST_F(ConnectionHandlerTest, TcpListenerGlobalCxLimitReject) { auto listener = new NiceMock(); TestListener* test_listener = addListener(1, true, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); listener_callbacks->onReject(Network::TcpListenerCallbacks::RejectCause::GlobalCxLimit); @@ -1506,7 +1597,8 @@ TEST_F(ConnectionHandlerTest, TcpListenerOverloadActionReject) { auto listener = new NiceMock(); TestListener* test_listener = addListener(1, true, false, "test_listener", listener, &listener_callbacks); - EXPECT_CALL(test_listener->socket_factory_, localAddress()).WillOnce(ReturnRef(local_address_)); + EXPECT_CALL(test_listener->socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener); listener_callbacks->onReject(Network::TcpListenerCallbacks::RejectCause::OverloadAction); @@ -1587,7 +1679,7 @@ TEST_F(ConnectionHandlerTest, DisableInternalListener) { TestListener* internal_listener = addInternalListener(1, "test_internal_listener", std::chrono::milliseconds(), false, nullptr); EXPECT_CALL(internal_listener->socket_factory_, localAddress()) - .WillOnce(ReturnRef(local_address)); + .WillRepeatedly(ReturnRef(local_address)); handler_->addListener(absl::nullopt, *internal_listener); auto internal_listener_cb = handler_->findByAddress(local_address); ASSERT_TRUE(internal_listener_cb.has_value()); @@ -1613,7 +1705,7 @@ TEST_F(ConnectionHandlerTest, InternalListenerInplaceUpdate) { TestListener* internal_listener = addInternalListener( old_listener_tag, "test_internal_listener", std::chrono::milliseconds(), false, nullptr); EXPECT_CALL(internal_listener->socket_factory_, localAddress()) - .WillOnce(ReturnRef(local_address)); + .WillRepeatedly(ReturnRef(local_address)); handler_->addListener(absl::nullopt, *internal_listener); ASSERT_NE(internal_listener, nullptr);