Skip to content
5 changes: 1 addition & 4 deletions include/envoy/network/io_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,7 @@ class IoHandle {
virtual Api::SysCallIntResult setBlocking(bool blocking) PURE;

/**
* Get domain used by underlying socket (see man 2 socket)
* @param domain updated to the underlying socket's domain if call is successful
* @return a Api::SysCallIntResult with rc_ = 0 for success and rc_ = -1 for failure. If the call
* is successful, errno_ shouldn't be used.
* @return the domain used by underlying socket (see man 2 socket)
*/
virtual absl::optional<int> domain() PURE;

Expand Down
3 changes: 2 additions & 1 deletion source/common/network/connection_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ void ConnectionImpl::noDelay(bool enable) {
}
#endif

RELEASE_ASSERT(result.rc_ == 0, "");
RELEASE_ASSERT(result.rc_ == 0, fmt::format("Failed to set TCP_NODELAY with error {}, {}",
result.errno_, errorDetails(result.errno_)));
}

void ConnectionImpl::onRead(uint64_t read_buffer_size) {
Expand Down
83 changes: 45 additions & 38 deletions source/common/network/io_socket_handle_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,42 @@ using Envoy::Api::SysCallIntResult;
using Envoy::Api::SysCallSizeResult;

namespace Envoy {

namespace {
/**
* On different platforms the sockaddr struct for unix domain
* sockets is different. We use this function to get the
* length of the platform specific struct.
*/
constexpr socklen_t udsAddressLength() {
#if defined(__APPLE__)
return sizeof(sockaddr);
#elif defined(WIN32)
return sizeof(sockaddr_un);
#else
return sizeof(sa_family_t);
#endif
}

constexpr int messageTypeContainsIP() {
#ifdef IP_RECVDSTADDR
return IP_RECVDSTADDR;
#else
return IP_PKTINFO;
#endif
}

in_addr addressFromMessage(const cmsghdr& cmsg) {
#ifdef IP_RECVDSTADDR
return *reinterpret_cast<const in_addr*>(CMSG_DATA(&cmsg));
#else
auto info = reinterpret_cast<const in_pktinfo*>(CMSG_DATA(&cmsg));
return info->ipi_addr;
#endif
}

} // namespace

namespace Network {

IoSocketHandleImpl::~IoSocketHandleImpl() {
Expand Down Expand Up @@ -175,37 +211,25 @@ Address::InstanceConstSharedPtr maybeGetDstAddressFromHeader(const cmsghdr& cmsg
ipv6_addr->sin6_port = htons(self_port);
return getAddressFromSockAddrOrDie(ss, sizeof(sockaddr_in6), fd);
}
#ifndef IP_RECVDSTADDR
if (cmsg.cmsg_type == IP_PKTINFO) {
auto info = reinterpret_cast<const in_pktinfo*>(CMSG_DATA(&cmsg));
#else
if (cmsg.cmsg_type == IP_RECVDSTADDR) {
auto addr = reinterpret_cast<const in_addr*>(CMSG_DATA(&cmsg));
#endif

if (cmsg.cmsg_type == messageTypeContainsIP()) {
sockaddr_storage ss;
auto ipv4_addr = reinterpret_cast<sockaddr_in*>(&ss);
memset(ipv4_addr, 0, sizeof(sockaddr_in));
ipv4_addr->sin_family = AF_INET;
ipv4_addr->sin_addr =
#ifndef IP_RECVDSTADDR
info->ipi_addr;
#else
*addr;
#endif
ipv4_addr->sin_addr = addressFromMessage(cmsg);
ipv4_addr->sin_port = htons(self_port);
return getAddressFromSockAddrOrDie(ss, sizeof(sockaddr_in), fd);
}

return nullptr;
}

absl::optional<uint32_t> maybeGetPacketsDroppedFromHeader(
absl::optional<uint32_t> maybeGetPacketsDroppedFromHeader([[maybe_unused]] const cmsghdr& cmsg) {
#ifdef SO_RXQ_OVFL
const cmsghdr& cmsg) {
if (cmsg.cmsg_type == SO_RXQ_OVFL) {
return *reinterpret_cast<const uint32_t*>(CMSG_DATA(&cmsg));
}
#else
const cmsghdr&) {
#endif
return absl::nullopt;
}
Expand Down Expand Up @@ -404,7 +428,7 @@ IoHandlePtr IoSocketHandleImpl::accept(struct sockaddr* addr, socklen_t* addrlen
return nullptr;
}

return std::make_unique<IoSocketHandleImpl>(result.rc_, socket_v6only_);
return std::make_unique<IoSocketHandleImpl>(result.rc_, socket_v6only_, domain_);
}

Api::SysCallIntResult IoSocketHandleImpl::connect(Address::InstanceConstSharedPtr address) {
Expand All @@ -425,20 +449,7 @@ Api::SysCallIntResult IoSocketHandleImpl::setBlocking(bool blocking) {
return Api::OsSysCallsSingleton::get().setsocketblocking(fd_, blocking);
}

absl::optional<int> IoSocketHandleImpl::domain() {
sockaddr_storage addr;
socklen_t len = sizeof(addr);
Api::SysCallIntResult result;

result = Api::OsSysCallsSingleton::get().getsockname(
fd_, reinterpret_cast<struct sockaddr*>(&addr), &len);

if (result.rc_ == 0) {
return {addr.ss_family};
}

return absl::nullopt;
}
absl::optional<int> IoSocketHandleImpl::domain() { return domain_; }

Address::InstanceConstSharedPtr IoSocketHandleImpl::localAddress() {
sockaddr_storage ss;
Expand All @@ -463,12 +474,8 @@ Address::InstanceConstSharedPtr IoSocketHandleImpl::peerAddress() {
throw EnvoyException(
fmt::format("getpeername failed for '{}': {}", fd_, errorDetails(result.errno_)));
}
#ifdef __APPLE__
if (ss_len == sizeof(sockaddr) && ss.ss_family == AF_UNIX)
#else
if (ss_len == sizeof(sa_family_t) && ss.ss_family == AF_UNIX)
#endif
{

if (ss_len == udsAddressLength() && ss.ss_family == AF_UNIX) {
// For Unix domain sockets, can't find out the peer name, but it should match our own
// name for the socket (i.e. the path should match, barring any namespace or other
// mechanisms to hide things, of which there are many).
Expand Down
6 changes: 4 additions & 2 deletions source/common/network/io_socket_handle_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ namespace Network {
*/
class IoSocketHandleImpl : public IoHandle, protected Logger::Loggable<Logger::Id::io> {
public:
explicit IoSocketHandleImpl(os_fd_t fd = INVALID_SOCKET, bool socket_v6only = false)
: fd_(fd), socket_v6only_(socket_v6only) {}
explicit IoSocketHandleImpl(os_fd_t fd = INVALID_SOCKET, bool socket_v6only = false,
absl::optional<int> domain = absl::nullopt)
: fd_(fd), socket_v6only_(socket_v6only), domain_(domain) {}

// Close underlying socket if close() hasn't been call yet.
~IoSocketHandleImpl() override;
Expand Down Expand Up @@ -85,6 +86,7 @@ class IoSocketHandleImpl : public IoHandle, protected Logger::Loggable<Logger::I

os_fd_t fd_;
int socket_v6only_{false};
const absl::optional<int> domain_;

// The minimum cmsg buffer size to filled in destination address, packets dropped and gso
// size when receiving a packet. It is possible for a received packet to contain both IPv4
Expand Down
1 change: 0 additions & 1 deletion source/common/network/socket_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ SocketImpl::SocketImpl(IoHandlePtr&& io_handle,
}

auto domain = io_handle_->domain();

// This should never happen in practice but too many tests inject fake fds ...
if (!domain.has_value()) {
return;
Expand Down
7 changes: 4 additions & 3 deletions source/common/network/socket_interface_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
namespace Envoy {
namespace Network {

IoHandlePtr SocketInterfaceImpl::makeSocket(int socket_fd, bool socket_v6only) const {
return std::make_unique<IoSocketHandleImpl>(socket_fd, socket_v6only);
IoHandlePtr SocketInterfaceImpl::makeSocket(int socket_fd, bool socket_v6only,
absl::optional<int> domain) const {
return std::make_unique<IoSocketHandleImpl>(socket_fd, socket_v6only, domain);
}

IoHandlePtr SocketInterfaceImpl::socket(Socket::Type socket_type, Address::Type addr_type,
Expand Down Expand Up @@ -48,7 +49,7 @@ IoHandlePtr SocketInterfaceImpl::socket(Socket::Type socket_type, Address::Type
const Api::SysCallSocketResult result = Api::OsSysCallsSingleton::get().socket(domain, flags, 0);
RELEASE_ASSERT(SOCKET_VALID(result.rc_),
fmt::format("socket(2) failed, got error: {}", errorDetails(result.errno_)));
IoHandlePtr io_handle = makeSocket(result.rc_, socket_v6only);
IoHandlePtr io_handle = makeSocket(result.rc_, socket_v6only, domain);

#if defined(__APPLE__) || defined(WIN32)
// Cannot set SOCK_NONBLOCK as a ::socket flag.
Expand Down
3 changes: 2 additions & 1 deletion source/common/network/socket_interface_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class SocketInterfaceImpl : public SocketInterfaceBase {
};

protected:
virtual IoHandlePtr makeSocket(int socket_fd, bool socket_v6only) const;
virtual IoHandlePtr makeSocket(int socket_fd, bool socket_v6only,
absl::optional<int> domain) const;
};

DECLARE_FACTORY(SocketInterfaceImpl);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ TEST_F(QuicIoHandleWrapperTest, DelegateIoHandleCalls) {
EXPECT_CALL(os_sys_calls_, sendmsg(fd, _, 0)).WillOnce(Return(Api::SysCallSizeResult{5u, 0}));
wrapper_->sendmsg(&slice, 1, 0, /*self_ip=*/nullptr, *addr);

EXPECT_CALL(os_sys_calls_, getsockname(_, _, _)).WillOnce(Return(Api::SysCallIntResult{0, 0}));
wrapper_->domain();

EXPECT_CALL(os_sys_calls_, getsockname(_, _, _))
Expand Down
1 change: 0 additions & 1 deletion test/integration/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,6 @@ envoy_cc_test(
"uds_integration_test.cc",
"uds_integration_test.h",
],
tags = ["fails_on_windows"],
deps = [
":http_integration_lib",
"//source/common/event:dispatcher_includes",
Expand Down
9 changes: 6 additions & 3 deletions test/integration/filters/test_socket_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ IoHandlePtr TestIoSocketHandle::accept(struct sockaddr* addr, socklen_t* addrlen
return nullptr;
}

return std::make_unique<TestIoSocketHandle>(writev_override_, result.rc_, socket_v6only_);
return std::make_unique<TestIoSocketHandle>(writev_override_, result.rc_, socket_v6only_,
domain_);
}

IoHandlePtr TestSocketInterface::makeSocket(int socket_fd, bool socket_v6only) const {
return std::make_unique<TestIoSocketHandle>(writev_override_proc_, socket_fd, socket_v6only);
IoHandlePtr TestSocketInterface::makeSocket(int socket_fd, bool socket_v6only,
absl::optional<int> domain) const {
return std::make_unique<TestIoSocketHandle>(writev_override_proc_, socket_fd, socket_v6only,
domain);
}

} // namespace Network
Expand Down
7 changes: 4 additions & 3 deletions test/integration/filters/test_socket_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class TestIoSocketHandle : public IoSocketHandleImpl {
using WritevOverrideProc = std::function<WritevOverrideType>;

TestIoSocketHandle(WritevOverrideProc writev_override_proc, os_fd_t fd = INVALID_SOCKET,
bool socket_v6only = false)
: IoSocketHandleImpl(fd, socket_v6only), writev_override_(writev_override_proc) {}
bool socket_v6only = false, absl::optional<int> domain = absl::nullopt)
: IoSocketHandleImpl(fd, socket_v6only, domain), writev_override_(writev_override_proc) {}

private:
IoHandlePtr accept(struct sockaddr* addr, socklen_t* addrlen) override;
Expand Down Expand Up @@ -57,7 +57,8 @@ class TestSocketInterface : public SocketInterfaceImpl {

private:
// SocketInterfaceImpl
IoHandlePtr makeSocket(int socket_fd, bool socket_v6only) const override;
IoHandlePtr makeSocket(int socket_fd, bool socket_v6only,
absl::optional<int> domain) const override;

const TestIoSocketHandle::WritevOverrideProc writev_override_proc_;
};
Expand Down