diff --git a/source/common/network/listen_socket_impl.cc b/source/common/network/listen_socket_impl.cc index 674cf456bce06..bd8a8b8bb0a3e 100644 --- a/source/common/network/listen_socket_impl.cc +++ b/source/common/network/listen_socket_impl.cc @@ -48,20 +48,16 @@ void ListenSocketImpl::setupSocket(const Network::Socket::OptionsSharedPtr& opti } } +// UDP listen socket desires io handle regardless bind_to_port is true or false. template <> -void NetworkListenSocket>::setPrebindSocketOptions() { -// On Windows, SO_REUSEADDR does not restrict subsequent bind calls when there is a listener as on -// Linux and later BSD socket stacks -#ifndef WIN32 - int on = 1; - auto status = setSocketOption(SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); - RELEASE_ASSERT(status.rc_ != -1, "failed to set SO_REUSEADDR socket option"); -#endif +NetworkListenSocket>::NetworkListenSocket( + const Address::InstanceConstSharedPtr& address, + const Network::Socket::OptionsSharedPtr& options, bool bind_to_port) + : ListenSocketImpl(Network::ioHandleForAddr(Socket::Type::Datagram, address), address) { + setPrebindSocketOptions(); + setupSocket(options, bind_to_port); } -template <> -void NetworkListenSocket>::setPrebindSocketOptions() {} - UdsListenSocket::UdsListenSocket(const Address::InstanceConstSharedPtr& address) : ListenSocketImpl(ioHandleForAddr(Socket::Type::Stream, address), address) { RELEASE_ASSERT(io_handle_->isOpen(), ""); diff --git a/source/common/network/listen_socket_impl.h b/source/common/network/listen_socket_impl.h index 663c372c6ee89..4a98916115317 100644 --- a/source/common/network/listen_socket_impl.h +++ b/source/common/network/listen_socket_impl.h @@ -13,6 +13,7 @@ #include "common/common/assert.h" #include "common/common/dump_state_utils.h" #include "common/network/socket_impl.h" +#include "common/network/socket_interface.h" namespace Envoy { namespace Network { @@ -25,7 +26,8 @@ class ListenSocketImpl : public SocketImpl { SocketPtr duplicate() override { // Using `new` to access a non-public constructor. return absl::WrapUnique( - new ListenSocketImpl(io_handle_->duplicate(), address_provider_->localAddress())); + new ListenSocketImpl(io_handle_ == nullptr ? nullptr : io_handle_->duplicate(), + address_provider_->localAddress())); } void setupSocket(const Network::Socket::OptionsSharedPtr& options, bool bind_to_port); @@ -50,11 +52,23 @@ template class NetworkListenSocket : public ListenSocketImpl { public: NetworkListenSocket(const Address::InstanceConstSharedPtr& address, const Network::Socket::OptionsSharedPtr& options, bool bind_to_port) - : ListenSocketImpl(Network::ioHandleForAddr(T::type, address), address) { - RELEASE_ASSERT(io_handle_->isOpen(), ""); - - setPrebindSocketOptions(); - + : ListenSocketImpl(bind_to_port ? Network::ioHandleForAddr(T::type, address) : nullptr, + address) { + // Prebind is applied if the socket is bind to port. + if (bind_to_port) { + RELEASE_ASSERT(io_handle_->isOpen(), ""); + setPrebindSocketOptions(); + } else { + // If the tcp listener does not bind to port, we test that the ip family is supported. + if (auto ip = address->ip(); ip != nullptr) { + RELEASE_ASSERT( + Network::SocketInterfaceSingleton::get().ipFamilySupported(ip->ipv4() ? AF_INET + : AF_INET6), + fmt::format( + "Creating listen socket address {} but the address familiy is not supported", + address->asStringView())); + } + } setupSocket(options, bind_to_port); } @@ -67,9 +81,30 @@ template class NetworkListenSocket : public ListenSocketImpl { Socket::Type socketType() const override { return T::type; } protected: - void setPrebindSocketOptions(); + void setPrebindSocketOptions() { + // On Windows, SO_REUSEADDR does not restrict subsequent bind calls when there is a listener as + // on Linux and later BSD socket stacks +#ifndef WIN32 + int on = 1; + auto status = setSocketOption(SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); + RELEASE_ASSERT(status.rc_ != -1, "failed to set SO_REUSEADDR socket option"); +#endif + } }; +template <> +inline void +NetworkListenSocket>::setPrebindSocketOptions() {} + +// UDP listen socket desires io handle regardless bind_to_port is true or false. +template <> +NetworkListenSocket>::NetworkListenSocket( + const Address::InstanceConstSharedPtr& address, + const Network::Socket::OptionsSharedPtr& options, bool bind_to_port); + +template class NetworkListenSocket>; +template class NetworkListenSocket>; + using TcpListenSocket = NetworkListenSocket>; using TcpListenSocketPtr = std::unique_ptr; diff --git a/source/common/network/tcp_listener_impl.cc b/source/common/network/tcp_listener_impl.cc index 38c77db31e441..09d56d566074f 100644 --- a/source/common/network/tcp_listener_impl.cc +++ b/source/common/network/tcp_listener_impl.cc @@ -42,6 +42,7 @@ bool TcpListenerImpl::rejectCxOverGlobalLimit() { } void TcpListenerImpl::onSocketEvent(short flags) { + ASSERT(bind_to_port_); ASSERT(flags & (Event::FileReadyType::Read)); // TODO(fcoras): Add limit on number of accepted calls per wakeup @@ -95,6 +96,8 @@ void TcpListenerImpl::onSocketEvent(short flags) { } void TcpListenerImpl::setupServerSocket(Event::DispatcherImpl& dispatcher, Socket& socket) { + ASSERT(bind_to_port_); + socket.ioHandle().listen(backlog_size_); // Although onSocketEvent drains to completion, use level triggered mode to avoid potential @@ -114,15 +117,27 @@ TcpListenerImpl::TcpListenerImpl(Event::DispatcherImpl& dispatcher, Random::Rand SocketSharedPtr socket, TcpListenerCallbacks& cb, bool bind_to_port, uint32_t backlog_size) : BaseListenerImpl(dispatcher, std::move(socket)), cb_(cb), backlog_size_(backlog_size), - random_(random), reject_fraction_(0.0) { + random_(random), bind_to_port_(bind_to_port), reject_fraction_(0.0) { if (bind_to_port) { setupServerSocket(dispatcher, *socket_); } } -void TcpListenerImpl::enable() { socket_->ioHandle().enableFileEvents(Event::FileReadyType::Read); } +void TcpListenerImpl::enable() { + if (bind_to_port_) { + socket_->ioHandle().enableFileEvents(Event::FileReadyType::Read); + } else { + FANCY_LOG(debug, "The listener cannot be enabled since it's not bind to port."); + } +} -void TcpListenerImpl::disable() { socket_->ioHandle().enableFileEvents(0); } +void TcpListenerImpl::disable() { + if (bind_to_port_) { + socket_->ioHandle().enableFileEvents(0); + } else { + FANCY_LOG(debug, "The listener cannot be disable since it's not bind to port."); + } +} void TcpListenerImpl::setRejectFraction(const UnitFloat reject_fraction) { reject_fraction_ = reject_fraction; diff --git a/source/common/network/tcp_listener_impl.h b/source/common/network/tcp_listener_impl.h index fd66ffc60deea..207b1b95abfc4 100644 --- a/source/common/network/tcp_listener_impl.h +++ b/source/common/network/tcp_listener_impl.h @@ -19,7 +19,11 @@ class TcpListenerImpl : public BaseListenerImpl { TcpListenerImpl(Event::DispatcherImpl& dispatcher, Random::RandomGenerator& random, SocketSharedPtr socket, TcpListenerCallbacks& cb, bool bind_to_port, uint32_t backlog_size); - ~TcpListenerImpl() override { socket_->ioHandle().resetFileEvents(); } + ~TcpListenerImpl() override { + if (bind_to_port_) { + socket_->ioHandle().resetFileEvents(); + } + } void disable() override; void enable() override; void setRejectFraction(UnitFloat reject_fraction) override; @@ -40,6 +44,7 @@ class TcpListenerImpl : public BaseListenerImpl { static bool rejectCxOverGlobalLimit(); Random::RandomGenerator& random_; + bool bind_to_port_; UnitFloat reject_fraction_; }; diff --git a/source/common/singleton/threadsafe_singleton.h b/source/common/singleton/threadsafe_singleton.h index 5b55dc0af5170..bf16f4b8cba73 100644 --- a/source/common/singleton/threadsafe_singleton.h +++ b/source/common/singleton/threadsafe_singleton.h @@ -74,7 +74,7 @@ template T* InjectableSingleton::loader_ = nullptr; template class ScopedInjectableLoader { public: - ScopedInjectableLoader(std::unique_ptr&& instance) { + explicit ScopedInjectableLoader(std::unique_ptr&& instance) { instance_ = std::move(instance); InjectableSingleton::initialize(instance_.get()); } @@ -84,4 +84,24 @@ template class ScopedInjectableLoader { std::unique_ptr instance_; }; +// This class saves the singleton object and restore the original singleton at destroy. This class +// is not thread safe. It can be used in single thread test. +template +class StackedScopedInjectableLoader : + // To access the protected loader_. + protected InjectableSingleton { +public: + explicit StackedScopedInjectableLoader(std::unique_ptr&& instance) { + original_loader_ = InjectableSingleton::getExisting(); + InjectableSingleton::clear(); + instance_ = std::move(instance); + InjectableSingleton::initialize(instance_.get()); + } + ~StackedScopedInjectableLoader() { InjectableSingleton::loader_ = original_loader_; } + +private: + std::unique_ptr instance_; + T* original_loader_; +}; + } // namespace Envoy diff --git a/test/common/network/listen_socket_impl_test.cc b/test/common/network/listen_socket_impl_test.cc index 789441ca571f8..9f50128898d82 100644 --- a/test/common/network/listen_socket_impl_test.cc +++ b/test/common/network/listen_socket_impl_test.cc @@ -1,3 +1,5 @@ +#include + #include "envoy/common/platform.h" #include "envoy/config/core/v3/base.pb.h" #include "envoy/network/exception.h" @@ -6,6 +8,7 @@ #include "common/api/os_sys_calls_impl.h" #include "common/network/io_socket_handle_impl.h" #include "common/network/listen_socket_impl.h" +#include "common/network/socket_interface_impl.h" #include "common/network/utility.h" #include "test/mocks/network/mocks.h" @@ -22,6 +25,18 @@ namespace Envoy { namespace Network { namespace { +class MockSingleFamilySocketInterface : public SocketInterfaceImpl { +public: + explicit MockSingleFamilySocketInterface(Address::IpVersion version) : version_(version) {} + MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, Address::Type, Address::IpVersion, bool), + (const)); + MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, const Address::InstanceConstSharedPtr), (const)); + bool ipFamilySupported(int domain) override { + return (version_ == Address::IpVersion::v4) ? domain == AF_INET : domain == AF_INET6; + } + const Address::IpVersion version_; +}; + TEST(ConnectionSocketImplTest, LowerCaseRequestedServerName) { absl::string_view serverName("www.EXAMPLE.com"); absl::string_view expectedServerName("www.example.com"); @@ -182,6 +197,40 @@ TEST_P(ListenSocketImplTestTcp, CheckIpVersionWithNullLocalAddress) { EXPECT_EQ(Address::IpVersion::v4, socket.ipVersion()); } +TEST_P(ListenSocketImplTestTcp, SupportedIpFamilyVirtualSocketIsCreatedWithNoBsdSocketCreated) { + auto mock_interface = std::make_unique(version_); + auto* mock_interface_ptr = mock_interface.get(); + auto any_address = version_ == Address::IpVersion::v4 ? Utility::getIpv4AnyAddress() + : Utility::getIpv6AnyAddress(); + + StackedScopedInjectableLoader new_interface(std::move(mock_interface)); + + { + EXPECT_CALL(*mock_interface_ptr, socket(_, _)).Times(0); + EXPECT_CALL(*mock_interface_ptr, socket(_, _, _, _)).Times(0); + TcpListenSocket virtual_listener_socket(any_address, nullptr, + /*bind_to_port*/ false); + } +} + +TEST_P(ListenSocketImplTestTcp, DeathAtUnSupportedIpFamilyListenSocket) { + auto mock_interface = std::make_unique(version_); + auto* mock_interface_ptr = mock_interface.get(); + auto the_other_address = version_ == Address::IpVersion::v4 ? Utility::getIpv6AnyAddress() + : Utility::getIpv4AnyAddress(); + StackedScopedInjectableLoader new_interface(std::move(mock_interface)); + { + EXPECT_CALL(*mock_interface_ptr, socket(_, _)).Times(0); + EXPECT_CALL(*mock_interface_ptr, socket(_, _, _, _)).Times(0); + EXPECT_DEATH( + { + TcpListenSocket virtual_listener_socket(the_other_address, nullptr, + /*bind_to_port*/ false); + }, + ".*"); + } +} + TEST_P(ListenSocketImplTestUdp, BindSpecificPort) { testBindSpecificPort(); } // Validate that we get port allocation when binding to port zero.