Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions source/common/network/listen_socket_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,16 @@ void ListenSocketImpl::setupSocket(const Network::Socket::OptionsSharedPtr& opti
}
}

// UDP listen socket desires io handle regardless bind_to_port is true or false.
template <>
void NetworkListenSocket<NetworkSocketTrait<Socket::Type::Stream>>::setPrebindSocketOptions() {
// On Windows, SO_REUSEADDR does not restrict subsequent bind calls when there is a listener as on
// Linux and later BSD socket stacks
#ifndef WIN32
int on = 1;
auto status = setSocketOption(SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
RELEASE_ASSERT(status.rc_ != -1, "failed to set SO_REUSEADDR socket option");
#endif
NetworkListenSocket<NetworkSocketTrait<Socket::Type::Datagram>>::NetworkListenSocket(
const Address::InstanceConstSharedPtr& address,
const Network::Socket::OptionsSharedPtr& options, bool bind_to_port)
: ListenSocketImpl(Network::ioHandleForAddr(Socket::Type::Datagram, address), address) {
setPrebindSocketOptions();
setupSocket(options, bind_to_port);
}

template <>
void NetworkListenSocket<NetworkSocketTrait<Socket::Type::Datagram>>::setPrebindSocketOptions() {}

UdsListenSocket::UdsListenSocket(const Address::InstanceConstSharedPtr& address)
: ListenSocketImpl(ioHandleForAddr(Socket::Type::Stream, address), address) {
RELEASE_ASSERT(io_handle_->isOpen(), "");
Expand Down
49 changes: 42 additions & 7 deletions source/common/network/listen_socket_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "common/common/assert.h"
#include "common/common/dump_state_utils.h"
#include "common/network/socket_impl.h"
#include "common/network/socket_interface.h"

namespace Envoy {
namespace Network {
Expand All @@ -25,7 +26,8 @@ class ListenSocketImpl : public SocketImpl {
SocketPtr duplicate() override {
// Using `new` to access a non-public constructor.
return absl::WrapUnique(
new ListenSocketImpl(io_handle_->duplicate(), address_provider_->localAddress()));
new ListenSocketImpl(io_handle_ == nullptr ? nullptr : io_handle_->duplicate(),
address_provider_->localAddress()));
}

void setupSocket(const Network::Socket::OptionsSharedPtr& options, bool bind_to_port);
Expand All @@ -50,11 +52,23 @@ template <typename T> class NetworkListenSocket : public ListenSocketImpl {
public:
NetworkListenSocket(const Address::InstanceConstSharedPtr& address,
const Network::Socket::OptionsSharedPtr& options, bool bind_to_port)
: ListenSocketImpl(Network::ioHandleForAddr(T::type, address), address) {
RELEASE_ASSERT(io_handle_->isOpen(), "");

setPrebindSocketOptions();

: ListenSocketImpl(bind_to_port ? Network::ioHandleForAddr(T::type, address) : nullptr,
address) {
// Prebind is applied if the socket is bind to port.
if (bind_to_port) {
RELEASE_ASSERT(io_handle_->isOpen(), "");
setPrebindSocketOptions();
} else {
// If the tcp listener does not bind to port, we test that the ip family is supported.
if (auto ip = address->ip(); ip != nullptr) {
RELEASE_ASSERT(
Network::SocketInterfaceSingleton::get().ipFamilySupported(ip->ipv4() ? AF_INET
: AF_INET6),
fmt::format(
"Creating listen socket address {} but the address familiy is not supported",
address->asStringView()));
}
}
setupSocket(options, bind_to_port);
}

Expand All @@ -67,9 +81,30 @@ template <typename T> class NetworkListenSocket : public ListenSocketImpl {
Socket::Type socketType() const override { return T::type; }

protected:
void setPrebindSocketOptions();
void setPrebindSocketOptions() {
// On Windows, SO_REUSEADDR does not restrict subsequent bind calls when there is a listener as
// on Linux and later BSD socket stacks
#ifndef WIN32
int on = 1;
auto status = setSocketOption(SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
RELEASE_ASSERT(status.rc_ != -1, "failed to set SO_REUSEADDR socket option");
#endif
}
};

template <>
inline void
NetworkListenSocket<NetworkSocketTrait<Socket::Type::Datagram>>::setPrebindSocketOptions() {}

// UDP listen socket desires io handle regardless bind_to_port is true or false.
template <>
NetworkListenSocket<NetworkSocketTrait<Socket::Type::Datagram>>::NetworkListenSocket(
const Address::InstanceConstSharedPtr& address,
const Network::Socket::OptionsSharedPtr& options, bool bind_to_port);

template class NetworkListenSocket<NetworkSocketTrait<Socket::Type::Stream>>;
template class NetworkListenSocket<NetworkSocketTrait<Socket::Type::Datagram>>;

using TcpListenSocket = NetworkListenSocket<NetworkSocketTrait<Socket::Type::Stream>>;
using TcpListenSocketPtr = std::unique_ptr<TcpListenSocket>;

Expand Down
21 changes: 18 additions & 3 deletions source/common/network/tcp_listener_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ bool TcpListenerImpl::rejectCxOverGlobalLimit() {
}

void TcpListenerImpl::onSocketEvent(short flags) {
ASSERT(bind_to_port_);
ASSERT(flags & (Event::FileReadyType::Read));

// TODO(fcoras): Add limit on number of accepted calls per wakeup
Expand Down Expand Up @@ -95,6 +96,8 @@ void TcpListenerImpl::onSocketEvent(short flags) {
}

void TcpListenerImpl::setupServerSocket(Event::DispatcherImpl& dispatcher, Socket& socket) {
ASSERT(bind_to_port_);

socket.ioHandle().listen(backlog_size_);

// Although onSocketEvent drains to completion, use level triggered mode to avoid potential
Expand All @@ -114,15 +117,27 @@ TcpListenerImpl::TcpListenerImpl(Event::DispatcherImpl& dispatcher, Random::Rand
SocketSharedPtr socket, TcpListenerCallbacks& cb,
bool bind_to_port, uint32_t backlog_size)
: BaseListenerImpl(dispatcher, std::move(socket)), cb_(cb), backlog_size_(backlog_size),
random_(random), reject_fraction_(0.0) {
random_(random), bind_to_port_(bind_to_port), reject_fraction_(0.0) {
if (bind_to_port) {
setupServerSocket(dispatcher, *socket_);
}
}

void TcpListenerImpl::enable() { socket_->ioHandle().enableFileEvents(Event::FileReadyType::Read); }
void TcpListenerImpl::enable() {
if (bind_to_port_) {
socket_->ioHandle().enableFileEvents(Event::FileReadyType::Read);
} else {
FANCY_LOG(debug, "The listener cannot be enabled since it's not bind to port.");
}
}

void TcpListenerImpl::disable() { socket_->ioHandle().enableFileEvents(0); }
void TcpListenerImpl::disable() {
if (bind_to_port_) {
socket_->ioHandle().enableFileEvents(0);
} else {
FANCY_LOG(debug, "The listener cannot be disable since it's not bind to port.");
}
}

void TcpListenerImpl::setRejectFraction(const UnitFloat reject_fraction) {
reject_fraction_ = reject_fraction;
Expand Down
7 changes: 6 additions & 1 deletion source/common/network/tcp_listener_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ class TcpListenerImpl : public BaseListenerImpl {
TcpListenerImpl(Event::DispatcherImpl& dispatcher, Random::RandomGenerator& random,
SocketSharedPtr socket, TcpListenerCallbacks& cb, bool bind_to_port,
uint32_t backlog_size);
~TcpListenerImpl() override { socket_->ioHandle().resetFileEvents(); }
~TcpListenerImpl() override {
if (bind_to_port_) {
socket_->ioHandle().resetFileEvents();
}
}
void disable() override;
void enable() override;
void setRejectFraction(UnitFloat reject_fraction) override;
Expand All @@ -40,6 +44,7 @@ class TcpListenerImpl : public BaseListenerImpl {
static bool rejectCxOverGlobalLimit();

Random::RandomGenerator& random_;
bool bind_to_port_;
UnitFloat reject_fraction_;
};

Expand Down
22 changes: 21 additions & 1 deletion source/common/singleton/threadsafe_singleton.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ template <class T> T* InjectableSingleton<T>::loader_ = nullptr;

template <class T> class ScopedInjectableLoader {
public:
ScopedInjectableLoader(std::unique_ptr<T>&& instance) {
explicit ScopedInjectableLoader(std::unique_ptr<T>&& instance) {
instance_ = std::move(instance);
InjectableSingleton<T>::initialize(instance_.get());
}
Expand All @@ -84,4 +84,24 @@ template <class T> class ScopedInjectableLoader {
std::unique_ptr<T> instance_;
};

// This class saves the singleton object and restore the original singleton at destroy. This class
// is not thread safe. It can be used in single thread test.
template <class T>
class StackedScopedInjectableLoader :
// To access the protected loader_.
protected InjectableSingleton<T> {
public:
explicit StackedScopedInjectableLoader(std::unique_ptr<T>&& instance) {
original_loader_ = InjectableSingleton<T>::getExisting();
InjectableSingleton<T>::clear();
instance_ = std::move(instance);
InjectableSingleton<T>::initialize(instance_.get());
}
~StackedScopedInjectableLoader() { InjectableSingleton<T>::loader_ = original_loader_; }

private:
std::unique_ptr<T> instance_;
T* original_loader_;
};

} // namespace Envoy
49 changes: 49 additions & 0 deletions test/common/network/listen_socket_impl_test.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <memory>

#include "envoy/common/platform.h"
#include "envoy/config/core/v3/base.pb.h"
#include "envoy/network/exception.h"
Expand All @@ -6,6 +8,7 @@
#include "common/api/os_sys_calls_impl.h"
#include "common/network/io_socket_handle_impl.h"
#include "common/network/listen_socket_impl.h"
#include "common/network/socket_interface_impl.h"
#include "common/network/utility.h"

#include "test/mocks/network/mocks.h"
Expand All @@ -22,6 +25,18 @@ namespace Envoy {
namespace Network {
namespace {

class MockSingleFamilySocketInterface : public SocketInterfaceImpl {
public:
explicit MockSingleFamilySocketInterface(Address::IpVersion version) : version_(version) {}
MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, Address::Type, Address::IpVersion, bool),
(const));
MOCK_METHOD(IoHandlePtr, socket, (Socket::Type, const Address::InstanceConstSharedPtr), (const));
bool ipFamilySupported(int domain) override {
return (version_ == Address::IpVersion::v4) ? domain == AF_INET : domain == AF_INET6;
}
const Address::IpVersion version_;
};

TEST(ConnectionSocketImplTest, LowerCaseRequestedServerName) {
absl::string_view serverName("www.EXAMPLE.com");
absl::string_view expectedServerName("www.example.com");
Expand Down Expand Up @@ -182,6 +197,40 @@ TEST_P(ListenSocketImplTestTcp, CheckIpVersionWithNullLocalAddress) {
EXPECT_EQ(Address::IpVersion::v4, socket.ipVersion());
}

TEST_P(ListenSocketImplTestTcp, SupportedIpFamilyVirtualSocketIsCreatedWithNoBsdSocketCreated) {
auto mock_interface = std::make_unique<MockSingleFamilySocketInterface>(version_);
auto* mock_interface_ptr = mock_interface.get();
auto any_address = version_ == Address::IpVersion::v4 ? Utility::getIpv4AnyAddress()
: Utility::getIpv6AnyAddress();

StackedScopedInjectableLoader<SocketInterface> new_interface(std::move(mock_interface));

{
EXPECT_CALL(*mock_interface_ptr, socket(_, _)).Times(0);
EXPECT_CALL(*mock_interface_ptr, socket(_, _, _, _)).Times(0);
TcpListenSocket virtual_listener_socket(any_address, nullptr,
/*bind_to_port*/ false);
}
}

TEST_P(ListenSocketImplTestTcp, DeathAtUnSupportedIpFamilyListenSocket) {
auto mock_interface = std::make_unique<MockSingleFamilySocketInterface>(version_);
auto* mock_interface_ptr = mock_interface.get();
auto the_other_address = version_ == Address::IpVersion::v4 ? Utility::getIpv6AnyAddress()
: Utility::getIpv4AnyAddress();
StackedScopedInjectableLoader<SocketInterface> new_interface(std::move(mock_interface));
{
EXPECT_CALL(*mock_interface_ptr, socket(_, _)).Times(0);
EXPECT_CALL(*mock_interface_ptr, socket(_, _, _, _)).Times(0);
EXPECT_DEATH(
{
TcpListenSocket virtual_listener_socket(the_other_address, nullptr,
/*bind_to_port*/ false);
},
".*");
}
}

TEST_P(ListenSocketImplTestUdp, BindSpecificPort) { testBindSpecificPort(); }

// Validate that we get port allocation when binding to port zero.
Expand Down