Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions source/common/network/listen_socket_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,6 @@ void ListenSocketImpl::setupSocket(const Network::Socket::OptionsSharedPtr& opti
}
}

// UDP listen socket desires io handle regardless bind_to_port is true or false.
template <>
NetworkListenSocket<NetworkSocketTrait<Socket::Type::Datagram>>::NetworkListenSocket(
const Address::InstanceConstSharedPtr& address,
const Network::Socket::OptionsSharedPtr& options, bool bind_to_port)
: ListenSocketImpl(Network::ioHandleForAddr(Socket::Type::Datagram, address), address) {
setPrebindSocketOptions();
setupSocket(options, bind_to_port);
}

UdsListenSocket::UdsListenSocket(const Address::InstanceConstSharedPtr& address)
: ListenSocketImpl(ioHandleForAddr(Socket::Type::Stream, address), address) {
RELEASE_ASSERT(io_handle_->isOpen(), "");
Expand Down
30 changes: 23 additions & 7 deletions source/common/network/listen_socket_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,32 @@ template <typename T> class NetworkListenSocket : public ListenSocketImpl {

Socket::Type socketType() const override { return T::type; }

// These four overrides are introduced to perform check. A null io handle is possible only if the
// the owner socket is a listen socket that does not bind to port.
IoHandle& ioHandle() override {
ASSERT(io_handle_ != nullptr);
return *io_handle_;
}
const IoHandle& ioHandle() const override {
ASSERT(io_handle_ != nullptr);
return *io_handle_;
}
void close() override {
if (io_handle_ != nullptr) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In working through my other change, I'm hitting a similar issue on setSocketOption (now that reuse port is the default). I think there is a similar bug there also. It makes me think we are going about the no bind sockets incorrectly in general if we need all these guards.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(i.e. why do no bind listeners even have sockets. I realize this is a much larger change and I'm trying to get through my current mess first, but this seems like something we might want to fix in the future.)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SetSocketOption should check bind_to_port w/ or w/o PR. At least ListenSocketImpl::setupSocket has a bind_to_port arg.

why do no bind listeners even have sockets

Good point! I was thinking about create a derived not-bind-listen-socket and you think deeper

if (io_handle_->isOpen()) {
io_handle_->close();
}
}
}
bool isOpen() const override {
ASSERT(io_handle_ != nullptr);
return io_handle_->isOpen();
}

protected:
void setPrebindSocketOptions() {
// On Windows, SO_REUSEADDR does not restrict subsequent bind calls when there is a listener as
// on Linux and later BSD socket stacks
// on Linux and later BSD socket stacks.
#ifndef WIN32
int on = 1;
auto status = setSocketOption(SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
Expand All @@ -96,12 +118,6 @@ template <>
inline void
NetworkListenSocket<NetworkSocketTrait<Socket::Type::Datagram>>::setPrebindSocketOptions() {}

// UDP listen socket desires io handle regardless bind_to_port is true or false.
template <>
NetworkListenSocket<NetworkSocketTrait<Socket::Type::Datagram>>::NetworkListenSocket(
const Address::InstanceConstSharedPtr& address,
const Network::Socket::OptionsSharedPtr& options, bool bind_to_port);

template class NetworkListenSocket<NetworkSocketTrait<Socket::Type::Stream>>;
template class NetworkListenSocket<NetworkSocketTrait<Socket::Type::Datagram>>;

Expand Down
8 changes: 6 additions & 2 deletions source/server/listener_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,18 @@ class ListenSocketFactoryImpl : public Network::ListenSocketFactory,
Network::SocketSharedPtr getListenSocket() override;

/**
* @return the socket shared by worker threads; otherwise return null.
* @return the socket shared by worker threads; otherwise return nullopt.
*/
Network::SocketOptRef sharedSocket() const override {
// If listen socket doesn't bind to port, consider it not shared.
if (!bind_to_port_) {
return absl::nullopt;
}
if (!reuse_port_) {
ASSERT(socket_ != nullptr);
return *socket_;
}
// If reuse_port is true, always return null, even socket_ is created for reserving
// If reuse_port is true, always return nullopt, even socket_ is created for reserving
// port number.
return absl::nullopt;
}
Expand Down
19 changes: 4 additions & 15 deletions test/common/network/listen_socket_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "source/common/api/os_sys_calls_impl.h"
#include "source/common/network/io_socket_handle_impl.h"
#include "source/common/network/listen_socket_impl.h"
#include "source/common/network/socket_interface_impl.h"
#include "source/common/network/utility.h"

#include "test/mocks/network/mocks.h"
Expand All @@ -25,18 +24,6 @@ namespace Envoy {
namespace Network {
namespace {

class MockSingleFamilySocketInterface : public SocketInterfaceImpl {
public:
explicit MockSingleFamilySocketInterface(Address::IpVersion version) : version_(version) {}
MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, Address::Type, Address::IpVersion, bool),
(const));
MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, const Address::InstanceConstSharedPtr), (const));
bool ipFamilySupported(int domain) override {
return (version_ == Address::IpVersion::v4) ? domain == AF_INET : domain == AF_INET6;
}
const Address::IpVersion version_;
};

TEST(ConnectionSocketImplTest, LowerCaseRequestedServerName) {
absl::string_view serverName("www.EXAMPLE.com");
absl::string_view expectedServerName("www.example.com");
Expand Down Expand Up @@ -198,7 +185,8 @@ TEST_P(ListenSocketImplTestTcp, CheckIpVersionWithNullLocalAddress) {
}

TEST_P(ListenSocketImplTestTcp, SupportedIpFamilyVirtualSocketIsCreatedWithNoBsdSocketCreated) {
auto mock_interface = std::make_unique<MockSingleFamilySocketInterface>(version_);
auto mock_interface =
std::make_unique<MockSocketInterface>(std::vector<Network::Address::IpVersion>{version_});
auto* mock_interface_ptr = mock_interface.get();
auto any_address = version_ == Address::IpVersion::v4 ? Utility::getIpv4AnyAddress()
: Utility::getIpv6AnyAddress();
Expand All @@ -214,7 +202,8 @@ TEST_P(ListenSocketImplTestTcp, SupportedIpFamilyVirtualSocketIsCreatedWithNoBsd
}

TEST_P(ListenSocketImplTestTcp, DeathAtUnSupportedIpFamilyListenSocket) {
auto mock_interface = std::make_unique<MockSingleFamilySocketInterface>(version_);
auto mock_interface =
std::make_unique<MockSocketInterface>(std::vector<Network::Address::IpVersion>{version_});
auto* mock_interface_ptr = mock_interface.get();
auto the_other_address = version_ == Address::IpVersion::v4 ? Utility::getIpv6AnyAddress()
: Utility::getIpv4AnyAddress();
Expand Down
8 changes: 3 additions & 5 deletions test/common/quic/active_quic_listener_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,6 @@ class ActiveQuicListenerFactoryPeer {

class ActiveQuicListenerTest : public QuicMultiVersionTest {
protected:
using Socket =
Network::NetworkListenSocket<Network::NetworkSocketTrait<Network::Socket::Type::Datagram>>;

ActiveQuicListenerTest()
: version_(GetParam().first), api_(Api::createApiForTest(simulated_time_system_)),
dispatcher_(api_->allocateDispatcher("test_thread")), clock_(*dispatcher_),
Expand Down Expand Up @@ -208,7 +205,8 @@ class ActiveQuicListenerTest : public QuicMultiVersionTest {
}

void sendCHLO(quic::QuicConnectionId connection_id) {
client_sockets_.push_back(std::make_unique<Socket>(local_address_, nullptr, /*bind*/ false));
client_sockets_.push_back(std::make_unique<Network::SocketImpl>(Network::Socket::Type::Datagram,
Comment on lines 211 to +208
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mattklein123 With this change, the io_handle == nullptr iff bind_to_port for listensocket, both tcp and udp.

Now UDP and TCP listener behavior is now unified including what you pointed out

local_address_, nullptr));
Buffer::OwnedImpl payload = generateChloPacketToSend(
quic_version_, quic_config_, ActiveQuicListenerPeer::cryptoConfig(*quic_listener_),
connection_id, clock_, envoyIpAddressToQuicSocketAddress(local_address_->ip()),
Expand Down Expand Up @@ -317,7 +315,7 @@ class ActiveQuicListenerTest : public QuicMultiVersionTest {
Init::MockManager init_manager_;
NiceMock<ProtobufMessage::MockValidationVisitor> validation_visitor_;

std::list<std::unique_ptr<Socket>> client_sockets_;
std::list<std::unique_ptr<Network::SocketImpl>> client_sockets_;
std::list<std::shared_ptr<Network::MockReadFilter>> read_filters_;
Network::MockFilterChainManager filter_chain_manager_;
// The following two containers must guarantee pointer stability as addresses
Expand Down
1 change: 1 addition & 0 deletions test/mocks/network/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ envoy_cc_mock(
"//envoy/network:transport_socket_interface",
"//envoy/server:listener_manager_interface",
"//source/common/network:address_lib",
"//source/common/network:socket_interface_lib",
"//source/common/network:utility_lib",
"//source/common/stats:isolated_store_lib",
"//test/mocks/event:event_mocks",
Expand Down
17 changes: 17 additions & 0 deletions test/mocks/network/mocks.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <algorithm>
#include <cstdint>
#include <list>
#include <ostream>
Expand All @@ -19,6 +20,7 @@

#include "source/common/network/filter_manager_impl.h"
#include "source/common/network/socket_interface.h"
#include "source/common/network/socket_interface_impl.h"
#include "source/common/stats/isolated_store_impl.h"

#include "test/mocks/event/mocks.h"
Expand Down Expand Up @@ -600,5 +602,20 @@ class MockUdpPacketProcessor : public UdpPacketProcessor {
MOCK_METHOD(size_t, numPacketsExpectedPerEventLoop, (), (const));
};

class MockSocketInterface : public SocketInterfaceImpl {
public:
explicit MockSocketInterface(const std::vector<Address::IpVersion>& versions)
: versions_(versions.begin(), versions.end()) {}
MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, Address::Type, Address::IpVersion, bool),
(const));
MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, const Address::InstanceConstSharedPtr), (const));
bool ipFamilySupported(int domain) override {
const auto to_version = domain == AF_INET ? Address::IpVersion::v4 : Address::IpVersion::v6;
return std::any_of(versions_.begin(), versions_.end(),
[to_version](auto version) { return to_version == version; });
}
const std::vector<Address::IpVersion> versions_;
};

} // namespace Network
} // namespace Envoy
70 changes: 63 additions & 7 deletions test/server/listener_manager_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "source/common/init/manager_impl.h"
#include "source/common/network/address_impl.h"
#include "source/common/network/io_socket_handle_impl.h"
#include "source/common/network/socket_interface_impl.h"
#include "source/common/network/utility.h"
#include "source/common/protobuf/protobuf.h"
#include "source/extensions/filters/listener/original_dst/original_dst.h"
Expand Down Expand Up @@ -1630,6 +1631,10 @@ name: foo

TEST_F(ListenerManagerImplTest, BindToPortEqualToFalse) {
InSequence s;
auto mock_interface = std::make_unique<Network::MockSocketInterface>(
std::vector<Network::Address::IpVersion>{Network::Address::IpVersion::v4});
StackedScopedInjectableLoader<Network::SocketInterface> new_interface(std::move(mock_interface));

ProdListenerComponentFactory real_listener_factory(server_);
EXPECT_CALL(*worker_, start(_, _));
manager_->startWorkers(guard_dog_, callback_.AsStdFunction());
Expand All @@ -1640,31 +1645,82 @@ name: foo
address: 127.0.0.1
port_value: 1234
bind_to_port: false
reuse_port: false
filter_chains:
- filters: []
)EOF";

auto syscall_result = os_sys_calls_actual_.socket(AF_INET, SOCK_STREAM, 0);
ASSERT_TRUE(SOCKET_VALID(syscall_result.rc_));

ListenerHandle* listener_foo = expectListenerCreate(true, true);
EXPECT_CALL(listener_factory_, createListenSocket(_, _, _, ListenSocketCreationParams(false)))
.WillOnce(Invoke([this, &syscall_result, &real_listener_factory](
.WillOnce(Invoke([this, &real_listener_factory](
const Network::Address::InstanceConstSharedPtr& address,
Network::Socket::Type socket_type,
const Network::Socket::OptionsSharedPtr& options,
const ListenSocketCreationParams& params) -> Network::SocketSharedPtr {
EXPECT_CALL(server_, hotRestart).Times(0);
// When bind_to_port is equal to false, create socket fd directly, and do not get socket
// fd through hot restart.
ON_CALL(os_sys_calls_, socket(AF_INET, _, 0)).WillByDefault(Return(syscall_result));
// When bind_to_port is equal to false, the BSD socket is not created at main thread.
EXPECT_CALL(os_sys_calls_, socket(AF_INET, _, 0)).Times(0);
return real_listener_factory.createListenSocket(address, socket_type, options, params);
}));
EXPECT_CALL(listener_foo->target_, initialize());
EXPECT_CALL(*listener_foo, onDestroy());
EXPECT_TRUE(manager_->addOrUpdateListener(parseListenerFromV3Yaml(listener_foo_yaml), "", true));
}

TEST_F(ListenerManagerImplTest, UpdateBindToPortEqualToFalse) {
InSequence s;
auto mock_interface = std::make_unique<Network::MockSocketInterface>(
std::vector<Network::Address::IpVersion>{Network::Address::IpVersion::v4});
StackedScopedInjectableLoader<Network::SocketInterface> new_interface(std::move(mock_interface));

ProdListenerComponentFactory real_listener_factory(server_);
EXPECT_CALL(*worker_, start(_, _));
manager_->startWorkers(guard_dog_, callback_.AsStdFunction());
const std::string listener_foo_yaml = R"EOF(
name: foo
address:
socket_address:
address: 127.0.0.1
port_value: 1234
bind_to_port: false
reuse_port: false
filter_chains:
- filters: []
)EOF";

ListenerHandle* listener_foo = expectListenerCreate(false, true);
EXPECT_CALL(listener_factory_, createListenSocket(_, _, _, ListenSocketCreationParams(false)))
.WillOnce(Invoke([this, &real_listener_factory](
const Network::Address::InstanceConstSharedPtr& address,
Network::Socket::Type socket_type,
const Network::Socket::OptionsSharedPtr& options,
const ListenSocketCreationParams& params) -> Network::SocketSharedPtr {
EXPECT_CALL(server_, hotRestart).Times(0);
// When bind_to_port is equal to false, the BSD socket is not created at main thread.
EXPECT_CALL(os_sys_calls_, socket(AF_INET, _, 0)).Times(0);
return real_listener_factory.createListenSocket(address, socket_type, options, params);
}));
EXPECT_CALL(*worker_, addListener(_, _, _));
EXPECT_TRUE(manager_->addOrUpdateListener(parseListenerFromV3Yaml(listener_foo_yaml), "", true));

worker_->callAddCompletion(true);

EXPECT_CALL(*listener_foo->drain_manager_, drainClose()).WillOnce(Return(false));
EXPECT_CALL(server_.drain_manager_, drainClose()).WillOnce(Return(false));
EXPECT_FALSE(listener_foo->context_->drainDecision().drainClose());

EXPECT_CALL(*worker_, stopListener(_, _));
EXPECT_CALL(*listener_foo->drain_manager_, startDrainSequence(_));

EXPECT_TRUE(manager_->removeListener("foo"));

EXPECT_CALL(*worker_, removeListener(_, _));
listener_foo->drain_manager_->drain_sequence_completion_();

EXPECT_CALL(*listener_foo, onDestroy());
worker_->callRemovalCompletion();
}

TEST_F(ListenerManagerImplTest, DEPRECATED_FEATURE_TEST(DeprecatedBindToPortEqualToFalse)) {
InSequence s;
ProdListenerComponentFactory real_listener_factory(server_);
Expand Down