Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 64 additions & 63 deletions source/common/network/io_uring_socket_handle_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "envoy/event/dispatcher.h"

#include "source/common/api/os_sys_calls_impl.h"
#include "source/common/buffer/buffer_impl.h"
#include "source/common/common/assert.h"
#include "source/common/common/utility.h"
#include "source/common/io/io_uring.h"
Expand Down Expand Up @@ -69,60 +68,50 @@ Api::IoCallUint64Result IoUringSocketHandleImpl::close() {

bool IoUringSocketHandleImpl::isOpen() const { return SOCKET_VALID(fd_); }

Api::IoCallUint64Result IoUringSocketHandleImpl::readv(uint64_t /* max_length */,
Buffer::RawSlice* slices,
uint64_t num_slice) {
if (bytes_to_read_ == 0 || read_buf_ == nullptr) {
Api::IoCallUint64Result
IoUringSocketHandleImpl::readv(uint64_t max_length, Buffer::RawSlice* slices, uint64_t num_slice) {
if (remote_closed_) {
return Api::ioCallUint64ResultNoError();
}

if (bytes_to_read_ < 0) {
return {0, Api::IoErrorPtr(new IoSocketError(-bytes_to_read_), IoSocketError::deleteIoError)};
}

if (bytes_to_read_ == 0 || read_req_ == nullptr) {
addReadRequest();
return {0, Api::IoErrorPtr(IoSocketError::getIoSocketEagainInstance(),
IoSocketError::deleteIoError)};
}
uint64_t num_slices_to_read = 0;
uint64_t num_bytes_to_read = 0;
for (;
num_slices_to_read < num_slice && num_bytes_to_read < static_cast<uint64_t>(bytes_to_read_);
num_slices_to_read++) {
const size_t slice_length = std::min(slices[num_slices_to_read].len_,
static_cast<size_t>(bytes_to_read_ - num_bytes_to_read));
memcpy(slices[num_slices_to_read].mem_, read_buf_.get() + num_bytes_to_read, slice_length);
num_bytes_to_read += slice_length;
}
ASSERT(num_bytes_to_read <= static_cast<uint64_t>(bytes_to_read_));
read_req_ = nullptr;

uint64_t len = bytes_to_read_;
bytes_to_read_ = 0;
return {len, Api::IoErrorPtr(nullptr, IoSocketError::deleteIoError)};

const uint64_t max_read_length = std::min(max_length, static_cast<uint64_t>(bytes_to_read_));
uint64_t num_bytes_to_read = read_buf_.copyOutToSlices(max_read_length, slices, num_slice);
ASSERT(num_bytes_to_read <= max_read_length);
read_buf_.drain(num_bytes_to_read);
bytes_to_read_ -= num_bytes_to_read;
if (bytes_to_read_ == 0) {
bytes_to_read_ = 0;
read_req_ = nullptr;
addReadRequest();
}

return {num_bytes_to_read, Api::IoErrorPtr(nullptr, IoSocketError::deleteIoError)};
}

Api::IoCallUint64Result IoUringSocketHandleImpl::read(Buffer::Instance& buffer,
absl::optional<uint64_t> max_length_opt) {
const uint64_t max_length = max_length_opt.value_or(UINT64_MAX);
if (max_length == 0 || remote_closed_) {
if (max_length == 0) {
return Api::ioCallUint64ResultNoError();
}

if (bytes_to_read_ < 0) {
return {0, Api::IoErrorPtr(new IoSocketError(bytes_to_read_), IoSocketError::deleteIoError)};
}

if (bytes_to_read_ == 0 || read_buf_ == nullptr) {
addReadRequest();
return {0, Api::IoErrorPtr(IoSocketError::getIoSocketEagainInstance(),
IoSocketError::deleteIoError)};
}
auto fragment = new Buffer::BufferFragmentImpl(
read_buf_.release(), bytes_to_read_,
[](const void* data, size_t /*len*/, const Buffer::BufferFragmentImpl* this_fragment) {
delete[] reinterpret_cast<const uint8_t*>(data);
delete this_fragment;
});
buffer.addBufferFragment(*fragment);
read_req_ = nullptr;

uint64_t len = bytes_to_read_;
bytes_to_read_ = 0;
return {len, Api::IoErrorPtr(nullptr, IoSocketError::deleteIoError)};
Buffer::Reservation reservation = buffer.reserveForRead();
Api::IoCallUint64Result result = readv(std::min(reservation.length(), max_length),
reservation.slices(), reservation.numSlices());
uint64_t bytes_to_commit = result.ok() ? result.return_value_ : 0;
ASSERT(bytes_to_commit <= max_length);
reservation.commit(bytes_to_commit);
return result;
}

Api::IoCallUint64Result IoUringSocketHandleImpl::writev(const Buffer::RawSlice* slices,
Expand All @@ -132,13 +121,14 @@ Api::IoCallUint64Result IoUringSocketHandleImpl::writev(const Buffer::RawSlice*
IoSocketError::deleteIoError)};
}

if (bytes_to_write_ < 0) {
return {0, Api::IoErrorPtr(new IoSocketError(bytes_to_read_), IoSocketError::deleteIoError)};
if (bytes_already_wrote_ < 0) {
return {
0, Api::IoErrorPtr(new IoSocketError(-bytes_already_wrote_), IoSocketError::deleteIoError)};
}

if (bytes_to_write_ > 0) {
uint64_t len = bytes_to_write_;
bytes_to_write_ = 0;
if (bytes_already_wrote_ > 0) {
uint64_t len = bytes_already_wrote_;
bytes_already_wrote_ = 0;
return {len, Api::IoErrorPtr(nullptr, IoSocketError::deleteIoError)};
}

Expand Down Expand Up @@ -173,8 +163,10 @@ Api::IoCallUint64Result IoUringSocketHandleImpl::writev(const Buffer::RawSlice*
}

Api::IoCallUint64Result IoUringSocketHandleImpl::write(Buffer::Instance& buffer) {
if (bytes_to_write_ > 0) {
buffer.drain(static_cast<uint64_t>(bytes_to_write_));
// If buffer gets written and drained, the following writev will return bytes_already_wrote_
// directly.
if (bytes_already_wrote_ > 0) {
buffer.drain(static_cast<uint64_t>(bytes_already_wrote_));
}

Buffer::RawSliceVector slices = buffer.getRawSlices();
Expand Down Expand Up @@ -359,17 +351,17 @@ void IoUringSocketHandleImpl::addReadRequest() {
return;
}

ASSERT(read_buf_ == nullptr);
read_buf_ = std::unique_ptr<uint8_t[]>(new uint8_t[read_buffer_size_]);
iov_.iov_base = read_buf_.get();
iov_.iov_len = read_buffer_size_;
auto& uring = io_uring_factory_.get().ref();
read_req_ = new Request{*this, RequestType::Read};
auto res = uring.prepareReadv(fd_, &iov_, 1, 0, read_req_);
read_req_->buf_ = std::make_unique<uint8_t[]>(read_buffer_size_);
read_req_->iov_ = new struct iovec[1];
read_req_->iov_->iov_base = read_req_->buf_.get();
read_req_->iov_->iov_len = read_buffer_size_;
auto& uring = io_uring_factory_.get().ref();
auto res = uring.prepareReadv(fd_, read_req_->iov_, 1, 0, read_req_);
if (res == Io::IoUringResult::Failed) {
// TODO(rojkov): handle `EBUSY` in case the completion queue is never reaped.
uring.submit();
res = uring.prepareReadv(fd_, &iov_, 1, 0, read_req_);
res = uring.prepareReadv(fd_, read_req_->iov_, 1, 0, read_req_);
RELEASE_ASSERT(res == Io::IoUringResult::Ok, "unable to prepare readv");
}
}
Expand Down Expand Up @@ -469,25 +461,31 @@ void IoUringSocketHandleImpl::FileEventAdapter::onRequestCompletion(const Reques
// This is hacky fix, we should check the req is valid or not.
if (iohandle.fd_ == -1) {
ENVOY_LOG_MISC(debug, "the uring's fd already closed");
break;
return;
}

iohandle.bytes_to_read_ = result;
if (result == 0) {
iohandle.remote_closed_ = true;
}
iohandle.cb_(Event::FileReadyType::Read);
if (result > 0) {
iohandle.addReadRequest();
Buffer::BufferFragment* fragment = new Buffer::BufferFragmentImpl(
const_cast<Request&>(req).buf_.release(), result,
[](const void* data, size_t /*len*/, const Buffer::BufferFragmentImpl* this_fragment) {
delete[] reinterpret_cast<const uint8_t*>(data);
delete this_fragment;
});
iohandle.read_buf_.addBufferFragment(*fragment);
}
iohandle.cb_(Event::FileReadyType::Read);
break;
}
case RequestType::Connect: {
ASSERT(req.iohandle_.has_value());
auto& iohandle = req.iohandle_->get();
if (result < 0) {
iohandle.cb_(Event::FileReadyType::Closed);
break;
return;
}

iohandle.cb_(Event::FileReadyType::Write);
Expand All @@ -500,10 +498,10 @@ void IoUringSocketHandleImpl::FileEventAdapter::onRequestCompletion(const Reques
// This is hacky fix, we should check the req is valid or not.
if (iohandle.fd_ == -1) {
ENVOY_LOG_MISC(debug, "the uring's fd already closed");
break;
return;
}

iohandle.bytes_to_write_ = result;
iohandle.bytes_already_wrote_ = result;
iohandle.is_write_added_ = false;
iohandle.cb_(Event::FileReadyType::Write);
break;
Expand All @@ -522,6 +520,9 @@ void IoUringSocketHandleImpl::FileEventAdapter::onFileEvent() {
uring.forEveryCompletion([this](void* user_data, int32_t result) {
auto req = static_cast<Request*>(user_data);
onRequestCompletion(*req, result);
if (req->iov_) {
delete[] req->iov_;
}
delete req;
});
uring.submit();
Expand Down
7 changes: 4 additions & 3 deletions source/common/network/io_uring_socket_handle_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "envoy/buffer/buffer.h"
#include "envoy/network/io_handle.h"

#include "source/common/buffer/buffer_impl.h"
#include "source/common/common/logger.h"
#include "source/common/io/io_uring.h"

Expand All @@ -21,6 +22,7 @@ struct Request {
IoUringSocketHandleImplOptRef iohandle_{absl::nullopt};
RequestType type_{RequestType::Unknown};
struct iovec* iov_{nullptr};
std::unique_ptr<uint8_t[]> buf_{};
};

/**
Expand Down Expand Up @@ -120,12 +122,11 @@ class IoUringSocketHandleImpl final : public IoHandle, protected Logger::Loggabl
const absl::optional<int> domain_;

Event::FileReadyCb cb_;
struct iovec iov_;
std::unique_ptr<uint8_t[]> read_buf_{nullptr};
Buffer::OwnedImpl read_buf_;
int32_t bytes_to_read_{0};
Request* read_req_{nullptr};
bool is_read_enabled_{true};
int32_t bytes_to_write_{0};
int32_t bytes_already_wrote_{0};
bool is_write_added_{false};
std::unique_ptr<FileEventAdapter> file_event_adapter_{nullptr};
bool remote_closed_{false};
Expand Down