diff --git a/source/common/network/io_uring_socket_handle_impl.cc b/source/common/network/io_uring_socket_handle_impl.cc index 904bd651fcefd..a5835b41c8db8 100644 --- a/source/common/network/io_uring_socket_handle_impl.cc +++ b/source/common/network/io_uring_socket_handle_impl.cc @@ -70,32 +70,18 @@ bool IoUringSocketHandleImpl::isOpen() const { return SOCKET_VALID(fd_); } Api::IoCallUint64Result IoUringSocketHandleImpl::readv(uint64_t max_length, Buffer::RawSlice* slices, uint64_t num_slice) { - if (remote_closed_) { - return Api::ioCallUint64ResultNoError(); - } - - if (bytes_to_read_ < 0) { - return {0, Api::IoErrorPtr(new IoSocketError(-bytes_to_read_), IoSocketError::deleteIoError)}; - } - - if (bytes_to_read_ == 0 || read_req_ == nullptr) { - addReadRequest(); - return {0, Api::IoErrorPtr(IoSocketError::getIoSocketEagainInstance(), - IoSocketError::deleteIoError)}; - } - - const uint64_t max_read_length = std::min(max_length, static_cast(bytes_to_read_)); - uint64_t num_bytes_to_read = read_buf_.copyOutToSlices(max_read_length, slices, num_slice); - ASSERT(num_bytes_to_read <= max_read_length); - read_buf_.drain(num_bytes_to_read); - bytes_to_read_ -= num_bytes_to_read; - if (bytes_to_read_ == 0) { - bytes_to_read_ = 0; - read_req_ = nullptr; - addReadRequest(); + Api::IoCallUint64Result result = copyOut(max_length, slices, num_slice); + if (result.ok()) { + read_buf_.drain(result.return_value_); + bytes_to_read_ -= result.return_value_; + if (bytes_to_read_ == 0) { + bytes_to_read_ = 0; + read_req_ = nullptr; + addReadRequest(); + } } - return {num_bytes_to_read, Api::IoErrorPtr(nullptr, IoSocketError::deleteIoError)}; + return result; } Api::IoCallUint64Result IoUringSocketHandleImpl::read(Buffer::Instance& buffer, @@ -193,9 +179,18 @@ Api::IoCallUint64Result IoUringSocketHandleImpl::recvmmsg(RawSliceArrays& /*slic PANIC("not implemented"); } -Api::IoCallUint64Result IoUringSocketHandleImpl::recv(void* /*buffer*/, size_t /*length*/, - int /*flags*/) { - PANIC("not implemented"); +Api::IoCallUint64Result IoUringSocketHandleImpl::recv(void* buffer, size_t length, int flags) { + Buffer::RawSlice slice; + slice.mem_ = buffer; + slice.len_ = length; + switch (flags) { + case 0: + return readv(length, &slice, 1); + case MSG_PEEK: + return copyOut(length, &slice, 1); + default: + PANIC("not implemented"); + } } bool IoUringSocketHandleImpl::supportsMmsg() const { PANIC("not implemented"); } @@ -347,6 +342,30 @@ Api::SysCallIntResult IoUringSocketHandleImpl::shutdown(int how) { return Api::OsSysCallsSingleton::get().shutdown(fd_, how); } +Api::IoCallUint64Result IoUringSocketHandleImpl::copyOut(uint64_t max_length, + Buffer::RawSlice* slices, + uint64_t num_slice) { + if (remote_closed_) { + return Api::ioCallUint64ResultNoError(); + } + + if (bytes_to_read_ < 0) { + return {0, Api::IoErrorPtr(new IoSocketError(-bytes_to_read_), IoSocketError::deleteIoError)}; + } + + if (bytes_to_read_ == 0 || read_req_ == nullptr) { + addReadRequest(); + return {0, Api::IoErrorPtr(IoSocketError::getIoSocketEagainInstance(), + IoSocketError::deleteIoError)}; + } + + const uint64_t max_read_length = std::min(max_length, static_cast(bytes_to_read_)); + uint64_t num_bytes_to_read = read_buf_.copyOutToSlices(max_read_length, slices, num_slice); + ASSERT(num_bytes_to_read <= max_read_length); + + return {num_bytes_to_read, Api::IoErrorPtr(nullptr, IoSocketError::deleteIoError)}; +} + void IoUringSocketHandleImpl::addReadRequest() { if (!is_read_enabled_ || !SOCKET_VALID(fd_) || read_req_) { return; diff --git a/source/common/network/io_uring_socket_handle_impl.h b/source/common/network/io_uring_socket_handle_impl.h index 9e733fdbb504d..060f194c367cf 100644 --- a/source/common/network/io_uring_socket_handle_impl.h +++ b/source/common/network/io_uring_socket_handle_impl.h @@ -107,6 +107,8 @@ class IoUringSocketHandleImpl final : public IoHandle, protected Logger::Loggabl socklen_t remote_addr_len_{sizeof(remote_addr_)}; }; + Api::IoCallUint64Result copyOut(uint64_t max_length, Buffer::RawSlice* slices, + uint64_t num_slice); void addReadRequest(); // Checks if the io handle is the one that registered eventfd with `io_uring`. // An io handle can be a leader in two cases: