diff --git a/source/common/network/listen_socket_impl.cc b/source/common/network/listen_socket_impl.cc index 4ed599d19b737..6f00446460ee0 100644 --- a/source/common/network/listen_socket_impl.cc +++ b/source/common/network/listen_socket_impl.cc @@ -48,16 +48,6 @@ void ListenSocketImpl::setupSocket(const Network::Socket::OptionsSharedPtr& opti } } -// UDP listen socket desires io handle regardless bind_to_port is true or false. -template <> -NetworkListenSocket>::NetworkListenSocket( - const Address::InstanceConstSharedPtr& address, - const Network::Socket::OptionsSharedPtr& options, bool bind_to_port) - : ListenSocketImpl(Network::ioHandleForAddr(Socket::Type::Datagram, address), address) { - setPrebindSocketOptions(); - setupSocket(options, bind_to_port); -} - UdsListenSocket::UdsListenSocket(const Address::InstanceConstSharedPtr& address) : ListenSocketImpl(ioHandleForAddr(Socket::Type::Stream, address), address) { RELEASE_ASSERT(io_handle_->isOpen(), ""); diff --git a/source/common/network/listen_socket_impl.h b/source/common/network/listen_socket_impl.h index 65e09b2b461ae..a237e2d5b7e4d 100644 --- a/source/common/network/listen_socket_impl.h +++ b/source/common/network/listen_socket_impl.h @@ -80,10 +80,32 @@ template class NetworkListenSocket : public ListenSocketImpl { Socket::Type socketType() const override { return T::type; } + // These four overrides are introduced to perform check. A null io handle is possible only if the + // the owner socket is a listen socket that does not bind to port. + IoHandle& ioHandle() override { + ASSERT(io_handle_ != nullptr); + return *io_handle_; + } + const IoHandle& ioHandle() const override { + ASSERT(io_handle_ != nullptr); + return *io_handle_; + } + void close() override { + if (io_handle_ != nullptr) { + if (io_handle_->isOpen()) { + io_handle_->close(); + } + } + } + bool isOpen() const override { + ASSERT(io_handle_ != nullptr); + return io_handle_->isOpen(); + } + protected: void setPrebindSocketOptions() { // On Windows, SO_REUSEADDR does not restrict subsequent bind calls when there is a listener as - // on Linux and later BSD socket stacks + // on Linux and later BSD socket stacks. #ifndef WIN32 int on = 1; auto status = setSocketOption(SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); @@ -96,12 +118,6 @@ template <> inline void NetworkListenSocket>::setPrebindSocketOptions() {} -// UDP listen socket desires io handle regardless bind_to_port is true or false. -template <> -NetworkListenSocket>::NetworkListenSocket( - const Address::InstanceConstSharedPtr& address, - const Network::Socket::OptionsSharedPtr& options, bool bind_to_port); - template class NetworkListenSocket>; template class NetworkListenSocket>; diff --git a/source/server/listener_impl.h b/source/server/listener_impl.h index 1fa60494715f3..42382a785964b 100644 --- a/source/server/listener_impl.h +++ b/source/server/listener_impl.h @@ -54,14 +54,18 @@ class ListenSocketFactoryImpl : public Network::ListenSocketFactory, Network::SocketSharedPtr getListenSocket() override; /** - * @return the socket shared by worker threads; otherwise return null. + * @return the socket shared by worker threads; otherwise return nullopt. */ Network::SocketOptRef sharedSocket() const override { + // If listen socket doesn't bind to port, consider it not shared. + if (!bind_to_port_) { + return absl::nullopt; + } if (!reuse_port_) { ASSERT(socket_ != nullptr); return *socket_; } - // If reuse_port is true, always return null, even socket_ is created for reserving + // If reuse_port is true, always return nullopt, even socket_ is created for reserving // port number. return absl::nullopt; } diff --git a/test/common/network/listen_socket_impl_test.cc b/test/common/network/listen_socket_impl_test.cc index 7e787507020fe..5c78af7a2e2c5 100644 --- a/test/common/network/listen_socket_impl_test.cc +++ b/test/common/network/listen_socket_impl_test.cc @@ -8,7 +8,6 @@ #include "source/common/api/os_sys_calls_impl.h" #include "source/common/network/io_socket_handle_impl.h" #include "source/common/network/listen_socket_impl.h" -#include "source/common/network/socket_interface_impl.h" #include "source/common/network/utility.h" #include "test/mocks/network/mocks.h" @@ -25,18 +24,6 @@ namespace Envoy { namespace Network { namespace { -class MockSingleFamilySocketInterface : public SocketInterfaceImpl { -public: - explicit MockSingleFamilySocketInterface(Address::IpVersion version) : version_(version) {} - MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, Address::Type, Address::IpVersion, bool), - (const)); - MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, const Address::InstanceConstSharedPtr), (const)); - bool ipFamilySupported(int domain) override { - return (version_ == Address::IpVersion::v4) ? domain == AF_INET : domain == AF_INET6; - } - const Address::IpVersion version_; -}; - TEST(ConnectionSocketImplTest, LowerCaseRequestedServerName) { absl::string_view serverName("www.EXAMPLE.com"); absl::string_view expectedServerName("www.example.com"); @@ -198,7 +185,8 @@ TEST_P(ListenSocketImplTestTcp, CheckIpVersionWithNullLocalAddress) { } TEST_P(ListenSocketImplTestTcp, SupportedIpFamilyVirtualSocketIsCreatedWithNoBsdSocketCreated) { - auto mock_interface = std::make_unique(version_); + auto mock_interface = + std::make_unique(std::vector{version_}); auto* mock_interface_ptr = mock_interface.get(); auto any_address = version_ == Address::IpVersion::v4 ? Utility::getIpv4AnyAddress() : Utility::getIpv6AnyAddress(); @@ -214,7 +202,8 @@ TEST_P(ListenSocketImplTestTcp, SupportedIpFamilyVirtualSocketIsCreatedWithNoBsd } TEST_P(ListenSocketImplTestTcp, DeathAtUnSupportedIpFamilyListenSocket) { - auto mock_interface = std::make_unique(version_); + auto mock_interface = + std::make_unique(std::vector{version_}); auto* mock_interface_ptr = mock_interface.get(); auto the_other_address = version_ == Address::IpVersion::v4 ? Utility::getIpv6AnyAddress() : Utility::getIpv4AnyAddress(); diff --git a/test/common/quic/active_quic_listener_test.cc b/test/common/quic/active_quic_listener_test.cc index 1dbd90bb6009b..edc1bb8fee0c7 100644 --- a/test/common/quic/active_quic_listener_test.cc +++ b/test/common/quic/active_quic_listener_test.cc @@ -75,9 +75,6 @@ class ActiveQuicListenerFactoryPeer { class ActiveQuicListenerTest : public QuicMultiVersionTest { protected: - using Socket = - Network::NetworkListenSocket>; - ActiveQuicListenerTest() : version_(GetParam().first), api_(Api::createApiForTest(simulated_time_system_)), dispatcher_(api_->allocateDispatcher("test_thread")), clock_(*dispatcher_), @@ -208,7 +205,8 @@ class ActiveQuicListenerTest : public QuicMultiVersionTest { } void sendCHLO(quic::QuicConnectionId connection_id) { - client_sockets_.push_back(std::make_unique(local_address_, nullptr, /*bind*/ false)); + client_sockets_.push_back(std::make_unique(Network::Socket::Type::Datagram, + local_address_, nullptr)); Buffer::OwnedImpl payload = generateChloPacketToSend( quic_version_, quic_config_, ActiveQuicListenerPeer::cryptoConfig(*quic_listener_), connection_id, clock_, envoyIpAddressToQuicSocketAddress(local_address_->ip()), @@ -317,7 +315,7 @@ class ActiveQuicListenerTest : public QuicMultiVersionTest { Init::MockManager init_manager_; NiceMock validation_visitor_; - std::list> client_sockets_; + std::list> client_sockets_; std::list> read_filters_; Network::MockFilterChainManager filter_chain_manager_; // The following two containers must guarantee pointer stability as addresses diff --git a/test/mocks/network/BUILD b/test/mocks/network/BUILD index 8bc9532c114d4..38cd2003b09d7 100644 --- a/test/mocks/network/BUILD +++ b/test/mocks/network/BUILD @@ -56,6 +56,7 @@ envoy_cc_mock( "//envoy/network:transport_socket_interface", "//envoy/server:listener_manager_interface", "//source/common/network:address_lib", + "//source/common/network:socket_interface_lib", "//source/common/network:utility_lib", "//source/common/stats:isolated_store_lib", "//test/mocks/event:event_mocks", diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index fb4f106b73046..b187695038725 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -19,6 +20,7 @@ #include "source/common/network/filter_manager_impl.h" #include "source/common/network/socket_interface.h" +#include "source/common/network/socket_interface_impl.h" #include "source/common/stats/isolated_store_impl.h" #include "test/mocks/event/mocks.h" @@ -600,5 +602,20 @@ class MockUdpPacketProcessor : public UdpPacketProcessor { MOCK_METHOD(size_t, numPacketsExpectedPerEventLoop, (), (const)); }; +class MockSocketInterface : public SocketInterfaceImpl { +public: + explicit MockSocketInterface(const std::vector& versions) + : versions_(versions.begin(), versions.end()) {} + MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, Address::Type, Address::IpVersion, bool), + (const)); + MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, const Address::InstanceConstSharedPtr), (const)); + bool ipFamilySupported(int domain) override { + const auto to_version = domain == AF_INET ? Address::IpVersion::v4 : Address::IpVersion::v6; + return std::any_of(versions_.begin(), versions_.end(), + [to_version](auto version) { return to_version == version; }); + } + const std::vector versions_; +}; + } // namespace Network } // namespace Envoy diff --git a/test/server/listener_manager_impl_test.cc b/test/server/listener_manager_impl_test.cc index 489bca559ace7..d72546c3dba8d 100644 --- a/test/server/listener_manager_impl_test.cc +++ b/test/server/listener_manager_impl_test.cc @@ -20,6 +20,7 @@ #include "source/common/init/manager_impl.h" #include "source/common/network/address_impl.h" #include "source/common/network/io_socket_handle_impl.h" +#include "source/common/network/socket_interface_impl.h" #include "source/common/network/utility.h" #include "source/common/protobuf/protobuf.h" #include "source/extensions/filters/listener/original_dst/original_dst.h" @@ -1630,6 +1631,10 @@ name: foo TEST_F(ListenerManagerImplTest, BindToPortEqualToFalse) { InSequence s; + auto mock_interface = std::make_unique( + std::vector{Network::Address::IpVersion::v4}); + StackedScopedInjectableLoader new_interface(std::move(mock_interface)); + ProdListenerComponentFactory real_listener_factory(server_); EXPECT_CALL(*worker_, start(_, _)); manager_->startWorkers(guard_dog_, callback_.AsStdFunction()); @@ -1640,24 +1645,21 @@ name: foo address: 127.0.0.1 port_value: 1234 bind_to_port: false +reuse_port: false filter_chains: - filters: [] )EOF"; - auto syscall_result = os_sys_calls_actual_.socket(AF_INET, SOCK_STREAM, 0); - ASSERT_TRUE(SOCKET_VALID(syscall_result.rc_)); - ListenerHandle* listener_foo = expectListenerCreate(true, true); EXPECT_CALL(listener_factory_, createListenSocket(_, _, _, ListenSocketCreationParams(false))) - .WillOnce(Invoke([this, &syscall_result, &real_listener_factory]( + .WillOnce(Invoke([this, &real_listener_factory]( const Network::Address::InstanceConstSharedPtr& address, Network::Socket::Type socket_type, const Network::Socket::OptionsSharedPtr& options, const ListenSocketCreationParams& params) -> Network::SocketSharedPtr { EXPECT_CALL(server_, hotRestart).Times(0); - // When bind_to_port is equal to false, create socket fd directly, and do not get socket - // fd through hot restart. - ON_CALL(os_sys_calls_, socket(AF_INET, _, 0)).WillByDefault(Return(syscall_result)); + // When bind_to_port is equal to false, the BSD socket is not created at main thread. + EXPECT_CALL(os_sys_calls_, socket(AF_INET, _, 0)).Times(0); return real_listener_factory.createListenSocket(address, socket_type, options, params); })); EXPECT_CALL(listener_foo->target_, initialize()); @@ -1665,6 +1667,60 @@ bind_to_port: false EXPECT_TRUE(manager_->addOrUpdateListener(parseListenerFromV3Yaml(listener_foo_yaml), "", true)); } +TEST_F(ListenerManagerImplTest, UpdateBindToPortEqualToFalse) { + InSequence s; + auto mock_interface = std::make_unique( + std::vector{Network::Address::IpVersion::v4}); + StackedScopedInjectableLoader new_interface(std::move(mock_interface)); + + ProdListenerComponentFactory real_listener_factory(server_); + EXPECT_CALL(*worker_, start(_, _)); + manager_->startWorkers(guard_dog_, callback_.AsStdFunction()); + const std::string listener_foo_yaml = R"EOF( +name: foo +address: + socket_address: + address: 127.0.0.1 + port_value: 1234 +bind_to_port: false +reuse_port: false +filter_chains: +- filters: [] + )EOF"; + + ListenerHandle* listener_foo = expectListenerCreate(false, true); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, _, ListenSocketCreationParams(false))) + .WillOnce(Invoke([this, &real_listener_factory]( + const Network::Address::InstanceConstSharedPtr& address, + Network::Socket::Type socket_type, + const Network::Socket::OptionsSharedPtr& options, + const ListenSocketCreationParams& params) -> Network::SocketSharedPtr { + EXPECT_CALL(server_, hotRestart).Times(0); + // When bind_to_port is equal to false, the BSD socket is not created at main thread. + EXPECT_CALL(os_sys_calls_, socket(AF_INET, _, 0)).Times(0); + return real_listener_factory.createListenSocket(address, socket_type, options, params); + })); + EXPECT_CALL(*worker_, addListener(_, _, _)); + EXPECT_TRUE(manager_->addOrUpdateListener(parseListenerFromV3Yaml(listener_foo_yaml), "", true)); + + worker_->callAddCompletion(true); + + EXPECT_CALL(*listener_foo->drain_manager_, drainClose()).WillOnce(Return(false)); + EXPECT_CALL(server_.drain_manager_, drainClose()).WillOnce(Return(false)); + EXPECT_FALSE(listener_foo->context_->drainDecision().drainClose()); + + EXPECT_CALL(*worker_, stopListener(_, _)); + EXPECT_CALL(*listener_foo->drain_manager_, startDrainSequence(_)); + + EXPECT_TRUE(manager_->removeListener("foo")); + + EXPECT_CALL(*worker_, removeListener(_, _)); + listener_foo->drain_manager_->drain_sequence_completion_(); + + EXPECT_CALL(*listener_foo, onDestroy()); + worker_->callRemovalCompletion(); +} + TEST_F(ListenerManagerImplTest, DEPRECATED_FEATURE_TEST(DeprecatedBindToPortEqualToFalse)) { InSequence s; ProdListenerComponentFactory real_listener_factory(server_);