Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions source/common/network/utility.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,12 @@ Address::InstanceConstSharedPtr Utility::getIpv6LoopbackAddress() {
new Address::Ipv6Instance("::1", 0, nullptr));
}

Address::InstanceConstSharedPtr Utility::getIpv4AnyAddress() {
CONSTRUCT_ON_FIRST_USE(Address::InstanceConstSharedPtr,
new Address::Ipv4Instance(static_cast<uint32_t>(0)));
Address::InstanceConstSharedPtr Utility::getIpv4AnyAddress(uint32_t port) {
CONSTRUCT_ON_FIRST_USE(Address::InstanceConstSharedPtr, new Address::Ipv4Instance(port));
}

Address::InstanceConstSharedPtr Utility::getIpv6AnyAddress() {
CONSTRUCT_ON_FIRST_USE(Address::InstanceConstSharedPtr,
new Address::Ipv6Instance(static_cast<uint32_t>(0)));
Address::InstanceConstSharedPtr Utility::getIpv6AnyAddress(uint32_t port) {
CONSTRUCT_ON_FIRST_USE(Address::InstanceConstSharedPtr, new Address::Ipv6Instance(port));
}

const std::string& Utility::getIpv4CidrCatchAllAddress() {
Expand Down
6 changes: 4 additions & 2 deletions source/common/network/utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,18 +246,20 @@ class Utility {
static Address::InstanceConstSharedPtr getIpv6LoopbackAddress();

/**
* @param port to be included in address, the default is 0.
* @return Address::InstanceConstSharedPtr an address that represents the IPv4 wildcard address
* (i.e. "0.0.0.0"). Used during binding to indicate that incoming connections to any
* local IPv4 address are to be accepted.
*/
static Address::InstanceConstSharedPtr getIpv4AnyAddress();
static Address::InstanceConstSharedPtr getIpv4AnyAddress(uint32_t port = 0);

/**
* @param port to be included in address, the default is 0.
* @return Address::InstanceConstSharedPtr an address that represents the IPv6 wildcard address
* (i.e. "::"). Used during binding to indicate that incoming connections to any local
* IPv6 address are to be accepted.
*/
static Address::InstanceConstSharedPtr getIpv6AnyAddress();
static Address::InstanceConstSharedPtr getIpv6AnyAddress(uint32_t port = 0);

/**
* @return the IPv4 CIDR catch-all address (0.0.0.0/0).
Expand Down
199 changes: 100 additions & 99 deletions source/server/connection_handler_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,60 +39,84 @@ void ConnectionHandlerImpl::addListener(absl::optional<uint64_t> overridden_list
return;
}

ActiveListenerDetails details;
auto details = std::make_shared<ActiveListenerDetails>();
if (config.internalListenerConfig().has_value()) {
if (overridden_listener.has_value()) {
for (auto& listener : listeners_) {
if (listener.second.listener_->listenerTag() == overridden_listener) {
listener.second.internalListener()->get().updateListenerConfig(config);
return;
}
if (auto iter = listener_map_by_tag_.find(overridden_listener.value());
iter != listener_map_by_tag_.end()) {
iter->second->internalListener()->get().updateListenerConfig(config);
return;
}
NOT_REACHED_GCOVR_EXCL_LINE;
}
auto internal_listener = std::make_unique<ActiveInternalListener>(*this, dispatcher(), config);
details.typed_listener_ = *internal_listener;
details.listener_ = std::move(internal_listener);
details->typed_listener_ = *internal_listener;
details->listener_ = std::move(internal_listener);
} else if (config.listenSocketFactory().socketType() == Network::Socket::Type::Stream) {
if (!support_udp_in_place_filter_chain_update && overridden_listener.has_value()) {
for (auto& listener : listeners_) {
if (listener.second.listener_->listenerTag() == overridden_listener) {
listener.second.tcpListener()->get().updateListenerConfig(config);
return;
}
if (auto iter = listener_map_by_tag_.find(overridden_listener.value());
iter != listener_map_by_tag_.end()) {
iter->second->tcpListener()->get().updateListenerConfig(config);
return;
}
NOT_REACHED_GCOVR_EXCL_LINE;
}
// worker_index_ doesn't have a value on the main thread for the admin server.
auto tcp_listener = std::make_unique<ActiveTcpListener>(
*this, config, worker_index_.has_value() ? *worker_index_ : 0);
details.typed_listener_ = *tcp_listener;
details.listener_ = std::move(tcp_listener);
details->typed_listener_ = *tcp_listener;
details->listener_ = std::move(tcp_listener);
} else {
ASSERT(config.udpListenerConfig().has_value(), "UDP listener factory is not initialized.");
ASSERT(worker_index_.has_value());
ConnectionHandler::ActiveUdpListenerPtr udp_listener =
config.udpListenerConfig()->listenerFactory().createActiveUdpListener(*worker_index_, *this,
dispatcher_, config);
details.typed_listener_ = *udp_listener;
details.listener_ = std::move(udp_listener);
details->typed_listener_ = *udp_listener;
details->listener_ = std::move(udp_listener);
}

if (disable_listeners_) {
details.listener_->pauseListening();
details->listener_->pauseListening();
}
if (auto* listener = details.listener_->listener(); listener != nullptr) {
if (auto* listener = details->listener_->listener(); listener != nullptr) {
listener->setRejectFraction(listener_reject_fraction_);
}
listeners_.emplace_back(config.listenSocketFactory().localAddress(), std::move(details));

details->listener_tag_ = config.listenerTag();
details->address_ = config.listenSocketFactory().localAddress();

ASSERT(!listener_map_by_tag_.contains(config.listenerTag()));

listener_map_by_tag_.emplace(config.listenerTag(), details);
// This map only store the new listener.
if (absl::holds_alternative<std::reference_wrapper<ActiveTcpListener>>(
details->typed_listener_)) {
tcp_listener_map_by_address_.insert_or_assign(
config.listenSocketFactory().localAddress()->asStringView(), details);
} else if (absl::holds_alternative<std::reference_wrapper<ActiveInternalListener>>(
details->typed_listener_)) {
internal_listener_map_by_address_.insert_or_assign(
config.listenSocketFactory().localAddress()->asStringView(), details);
}
}

void ConnectionHandlerImpl::removeListeners(uint64_t listener_tag) {
for (auto listener = listeners_.begin(); listener != listeners_.end();) {
if (listener->second.listener_->listenerTag() == listener_tag) {
listener = listeners_.erase(listener);
} else {
++listener;
if (auto listener_iter = listener_map_by_tag_.find(listener_tag);
listener_iter != listener_map_by_tag_.end()) {
// listener_map_by_address_ may already update to the new listener. Compare it with the one
// which find from listener_map_by_tag_, only delete it when it is same listener.
auto address_view = listener_iter->second->address_->asStringView();
if (tcp_listener_map_by_address_.contains(address_view) &&
tcp_listener_map_by_address_[address_view]->listener_tag_ ==
listener_iter->second->listener_tag_) {
tcp_listener_map_by_address_.erase(address_view);
} else if (internal_listener_map_by_address_.contains(address_view) &&
internal_listener_map_by_address_[address_view]->listener_tag_ ==
listener_iter->second->listener_tag_) {
internal_listener_map_by_address_.erase(address_view);
}
listener_map_by_tag_.erase(listener_iter);
}
}

Expand All @@ -112,68 +136,66 @@ ConnectionHandlerImpl::getUdpListenerCallbacks(uint64_t listener_tag) {
void ConnectionHandlerImpl::removeFilterChains(
uint64_t listener_tag, const std::list<const Network::FilterChain*>& filter_chains,
std::function<void()> completion) {
for (auto& listener : listeners_) {
if (listener.second.listener_->listenerTag() == listener_tag) {
listener.second.listener_->onFilterChainDraining(filter_chains);
break;
}
if (auto listener_it = listener_map_by_tag_.find(listener_tag);
listener_it != listener_map_by_tag_.end()) {
listener_it->second->listener_->onFilterChainDraining(filter_chains);
}

// Reach here if the target listener is found or the target listener was removed by a full
// listener update. In either case, the completion must be deferred so that any active connection
// referencing the filter chain can finish prior to deletion.
Event::DeferredTaskUtil::deferredRun(dispatcher_, std::move(completion));
}

void ConnectionHandlerImpl::stopListeners(uint64_t listener_tag) {
for (auto& listener : listeners_) {
if (listener.second.listener_->listenerTag() == listener_tag) {
listener.second.listener_->shutdownListener();
if (auto iter = listener_map_by_tag_.find(listener_tag); iter != listener_map_by_tag_.end()) {
if (iter->second->listener_->listener() != nullptr) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can it be in this map, and have a nullptr listener?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will stop the listener before draining it, the shutdownListener() https://github.com/envoyproxy/envoy/pull/19362/files#diff-fe538327af200510b78cadc759f2c3af989c822a370d3f8a005b92bc59568aa5R153 will reset the pointer. Then it is going to be nullptr. That means if we shutdown twice, then the second time you will see a nullptr.

iter->second->listener_->shutdownListener();
}
}
}

void ConnectionHandlerImpl::stopListeners() {
for (auto& listener : listeners_) {
listener.second.listener_->shutdownListener();
for (auto& iter : listener_map_by_tag_) {
if (iter.second->listener_->listener() != nullptr) {
iter.second->listener_->shutdownListener();
}
}
}

void ConnectionHandlerImpl::disableListeners() {
disable_listeners_ = true;
for (auto& listener : listeners_) {
listener.second.listener_->pauseListening();
for (auto& iter : listener_map_by_tag_) {
if (iter.second->listener_->listener() != nullptr) {
iter.second->listener_->pauseListening();
}
}
}

void ConnectionHandlerImpl::enableListeners() {
disable_listeners_ = false;
for (auto& listener : listeners_) {
listener.second.listener_->resumeListening();
for (auto& iter : listener_map_by_tag_) {
if (iter.second->listener_->listener() != nullptr) {
iter.second->listener_->resumeListening();
}
}
}

void ConnectionHandlerImpl::setListenerRejectFraction(UnitFloat reject_fraction) {
listener_reject_fraction_ = reject_fraction;
for (auto& listener : listeners_) {
listener.second.listener_->listener()->setRejectFraction(reject_fraction);
for (auto& iter : listener_map_by_tag_) {
if (iter.second->listener_->listener() != nullptr) {
iter.second->listener_->listener()->setRejectFraction(reject_fraction);
}
}
}

Network::InternalListenerOptRef
ConnectionHandlerImpl::findByAddress(const Network::Address::InstanceConstSharedPtr& address) {
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking to change this method name as findInternalListenerByAddress, since the assertion as below, this method is only used for internal listener.

ASSERT(address->type() == Network::Address::Type::EnvoyInternal);
auto listener_it =
std::find_if(listeners_.begin(), listeners_.end(),
[&address](std::pair<Network::Address::InstanceConstSharedPtr,
ConnectionHandlerImpl::ActiveListenerDetails>& p) {
return p.second.internalListener().has_value() &&
p.second.listener_->listener() != nullptr &&
p.first->type() == Network::Address::Type::EnvoyInternal &&
*(p.first) == *address;
});

if (listener_it != listeners_.end()) {
return Network::InternalListenerOptRef(listener_it->second.internalListener().value().get());
if (auto listener_it = internal_listener_map_by_address_.find(address->asStringView());
listener_it != internal_listener_map_by_address_.end()) {
return Network::InternalListenerOptRef(listener_it->second->internalListener().value().get());
}
return OptRef<Network::InternalListener>();
}
Expand All @@ -198,15 +220,9 @@ ConnectionHandlerImpl::ActiveListenerDetails::internalListener() {

ConnectionHandlerImpl::ActiveListenerDetailsOptRef
ConnectionHandlerImpl::findActiveListenerByTag(uint64_t listener_tag) {
// TODO(mattklein123): We should probably use a hash table here to lookup the tag
// instead of iterating through the listener list.
for (auto& listener : listeners_) {
if (listener.second.listener_->listener() != nullptr &&
listener.second.listener_->listenerTag() == listener_tag) {
return listener.second;
}
if (auto iter = listener_map_by_tag_.find(listener_tag); iter != listener_map_by_tag_.end()) {
return *iter->second;
}

return absl::nullopt;
}

Expand All @@ -224,58 +240,43 @@ ConnectionHandlerImpl::getBalancedHandlerByTag(uint64_t listener_tag) {

Network::BalancedConnectionHandlerOptRef
ConnectionHandlerImpl::getBalancedHandlerByAddress(const Network::Address::Instance& address) {
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also think about change this method name to getBalancedTCPHandlerByAddress, since this method is only used for TCP listener

// This is a linear operation, may need to add a map<address, listener> to improve performance.
// However, linear performance might be adequate since the number of listeners is small.
// We do not return stopped listeners.
auto listener_it =
std::find_if(listeners_.begin(), listeners_.end(),
[&address](std::pair<Network::Address::InstanceConstSharedPtr,
ConnectionHandlerImpl::ActiveListenerDetails>& p) {
return p.second.tcpListener().has_value() &&
p.second.listener_->listener() != nullptr &&
p.first->type() == Network::Address::Type::Ip && *(p.first) == address;
});

// If there is exact address match, return the corresponding listener.
if (listener_it != listeners_.end()) {
if (auto listener_it = tcp_listener_map_by_address_.find(address.asStringView());
listener_it != tcp_listener_map_by_address_.end()) {
return Network::BalancedConnectionHandlerOptRef(
listener_it->second.tcpListener().value().get());
listener_it->second->tcpListener().value().get());
}

OptRef<ConnectionHandlerImpl::ActiveListenerDetails> details;
// Otherwise, we need to look for the wild card match, i.e., 0.0.0.0:[address_port].
// We do not return stopped listeners.
// TODO(wattli): consolidate with previous search for more efficiency.
if (Runtime::runtimeFeatureEnabled(
"envoy.reloadable_features.listener_wildcard_match_ip_family")) {
listener_it =
std::find_if(listeners_.begin(), listeners_.end(),
[&address](const std::pair<Network::Address::InstanceConstSharedPtr,
ConnectionHandlerImpl::ActiveListenerDetails>& p) {
return absl::holds_alternative<std::reference_wrapper<ActiveTcpListener>>(
p.second.typed_listener_) &&
p.second.listener_->listener() != nullptr &&
p.first->type() == Network::Address::Type::Ip &&
p.first->ip()->port() == address.ip()->port() &&
p.first->ip()->isAnyAddress() &&
p.first->ip()->version() == address.ip()->version();
});
std::string addr_str =
address.ip()->version() == Network::Address::IpVersion::v4
? Network::Utility::getIpv4AnyAddress(address.ip()->port())->asString()
: Network::Utility::getIpv6AnyAddress(address.ip()->port())->asString();

auto iter = tcp_listener_map_by_address_.find(addr_str);
if (iter != tcp_listener_map_by_address_.end()) {
details = *iter->second;
}
} else {
listener_it =
std::find_if(listeners_.begin(), listeners_.end(),
[&address](const std::pair<Network::Address::InstanceConstSharedPtr,
ConnectionHandlerImpl::ActiveListenerDetails>& p) {
return absl::holds_alternative<std::reference_wrapper<ActiveTcpListener>>(
p.second.typed_listener_) &&
p.second.listener_->listener() != nullptr &&
p.first->type() == Network::Address::Type::Ip &&
p.first->ip()->port() == address.ip()->port() &&
p.first->ip()->isAnyAddress();
});
for (auto& iter : tcp_listener_map_by_address_) {
if (iter.second->listener_->listener() != nullptr &&
iter.second->address_->type() == Network::Address::Type::Ip &&
iter.second->address_->ip()->port() == address.ip()->port() &&
iter.second->address_->ip()->isAnyAddress()) {
details = *iter.second;
}
}
}
return (listener_it != listeners_.end())
return (details.has_value())
? Network::BalancedConnectionHandlerOptRef(
ActiveTcpListenerOptRef(absl::get<std::reference_wrapper<ActiveTcpListener>>(
listener_it->second.typed_listener_))
details->typed_listener_))
.value()
.get())
: absl::nullopt;
Expand Down
9 changes: 8 additions & 1 deletion source/server/connection_handler_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class ConnectionHandlerImpl : public Network::TcpConnectionHandler,
struct ActiveListenerDetails {
// Strong pointer to the listener, whether TCP, UDP, QUIC, etc.
Network::ConnectionHandler::ActiveListenerPtr listener_;
Network::Address::InstanceConstSharedPtr address_;
uint64_t listener_tag_;

absl::variant<absl::monostate, std::reference_wrapper<ActiveTcpListener>,
std::reference_wrapper<Network::UdpListenerCallbacks>,
Expand All @@ -93,7 +95,12 @@ class ConnectionHandlerImpl : public Network::TcpConnectionHandler,
const absl::optional<uint32_t> worker_index_;
Event::Dispatcher& dispatcher_;
const std::string per_handler_stat_prefix_;
std::list<std::pair<Network::Address::InstanceConstSharedPtr, ActiveListenerDetails>> listeners_;
absl::flat_hash_map<uint64_t, std::shared_ptr<ActiveListenerDetails>> listener_map_by_tag_;
absl::flat_hash_map<std::string, std::shared_ptr<ActiveListenerDetails>>
tcp_listener_map_by_address_;
absl::flat_hash_map<std::string, std::shared_ptr<ActiveListenerDetails>>
internal_listener_map_by_address_;

std::atomic<uint64_t> num_handler_connections_{};
bool disable_listeners_;
UnitFloat listener_reject_fraction_{UnitFloat::min()};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ProxyProtocolRegressionTest : public testing::TestWithParam<Network::Addre
init_manager_(nullptr) {
EXPECT_CALL(socket_factory_, socketType()).WillOnce(Return(Network::Socket::Type::Stream));
EXPECT_CALL(socket_factory_, localAddress())
.WillOnce(ReturnRef(socket_->connectionInfoProvider().localAddress()));
.WillRepeatedly(ReturnRef(socket_->connectionInfoProvider().localAddress()));
EXPECT_CALL(socket_factory_, getListenSocket(_)).WillOnce(Return(socket_));
connection_handler_->addListener(absl::nullopt, *this);
conn_ = dispatcher_->createClientConnection(socket_->connectionInfoProvider().localAddress(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ProxyProtocolTest : public testing::TestWithParam<Network::Address::IpVers
init_manager_(nullptr) {
EXPECT_CALL(socket_factory_, socketType()).WillOnce(Return(Network::Socket::Type::Stream));
EXPECT_CALL(socket_factory_, localAddress())
.WillOnce(ReturnRef(socket_->connectionInfoProvider().localAddress()));
.WillRepeatedly(ReturnRef(socket_->connectionInfoProvider().localAddress()));
EXPECT_CALL(socket_factory_, getListenSocket(_)).WillOnce(Return(socket_));
connection_handler_->addListener(absl::nullopt, *this);
conn_ = dispatcher_->createClientConnection(socket_->connectionInfoProvider().localAddress(),
Expand Down Expand Up @@ -1369,7 +1369,7 @@ class WildcardProxyProtocolTest : public testing::TestWithParam<Network::Address
init_manager_(nullptr) {
EXPECT_CALL(socket_factory_, socketType()).WillOnce(Return(Network::Socket::Type::Stream));
EXPECT_CALL(socket_factory_, localAddress())
.WillOnce(ReturnRef(socket_->connectionInfoProvider().localAddress()));
.WillRepeatedly(ReturnRef(socket_->connectionInfoProvider().localAddress()));
EXPECT_CALL(socket_factory_, getListenSocket(_)).WillOnce(Return(socket_));
connection_handler_->addListener(absl::nullopt, *this);
conn_ = dispatcher_->createClientConnection(local_dst_address_,
Expand Down
Loading