Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions source/common/network/io_uring_socket_handle_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ IoUringSocketHandleImpl::~IoUringSocketHandleImpl() {
Api::IoCallUint64Result IoUringSocketHandleImpl::close() {
ASSERT(SOCKET_VALID(fd_));
auto req = new Request{absl::nullopt, RequestType::Close};
req->closed_ = true;

// put the request into the map.
req->id_ = global_request_id++;
req->fd_ = fd_;
request_map_.insert({req->id_, req});

for (auto& req : request_map_) {
if (req.second->fd_ == fd_) {
req.second->closed_ = true;
}
}

Io::IoUringResult res = io_uring_factory_.get().ref().prepareClose(fd_, req);
if (res == Io::IoUringResult::Failed) {
// Fall back to posix system call.
Expand All @@ -62,6 +75,15 @@ Api::IoCallUint64Result IoUringSocketHandleImpl::readv(uint64_t /* max_length */
return {0, Api::IoErrorPtr(IoSocketError::getIoSocketEagainInstance(),
IoSocketError::deleteIoError)};
}

if (remote_closed_) {
return Api::ioCallUint64ResultNoError();
}

if (bytes_to_read_ < 0) {
return {0, Api::IoErrorPtr(new IoSocketError(bytes_to_read_), IoSocketError::deleteIoError)};
}

uint64_t num_slices_to_read = 0;
uint64_t num_bytes_to_read = 0;
for (;
Expand Down Expand Up @@ -190,6 +212,12 @@ IoHandlePtr IoUringSocketHandleImpl::accept(struct sockaddr* addr, socklen_t* ad
Api::SysCallIntResult IoUringSocketHandleImpl::connect(Address::InstanceConstSharedPtr address) {
auto& uring = io_uring_factory_.get().ref();
auto req = new Request{*this, RequestType::Connect};

// put the request into the map.
req->id_ = global_request_id++;
req->fd_ = fd_;
request_map_.insert({req->id_, req});

auto res = uring.prepareConnect(fd_, address, req);
if (res == Io::IoUringResult::Failed) {
res = uring.submit();
Expand Down Expand Up @@ -327,6 +355,12 @@ void IoUringSocketHandleImpl::addReadRequest() {
iov_.iov_len = read_buffer_size_;
auto& uring = io_uring_factory_.get().ref();
auto req = new Request{*this, RequestType::Read};

// put the request into the map.
req->id_ = global_request_id++;
req->fd_ = fd_;
request_map_.insert({req->id_, req});

auto res = uring.prepareReadv(fd_, &iov_, 1, 0, req);
if (res == Io::IoUringResult::Failed) {
// TODO(rojkov): handle `EBUSY` in case the completion queue is never reaped.
Expand All @@ -353,6 +387,12 @@ void IoUringSocketHandleImpl::addWriteRequest() {
}

auto req = new Request{*this, RequestType::Write, iovecs, std::move(write_buf_)};

// put the request into the map.
req->id_ = global_request_id++;
req->fd_ = fd_;
request_map_.insert({req->id_, req});

write_buf_ = std::list<Buffer::SliceDataPtr>{};
auto& uring = io_uring_factory_.get().ref();
auto res = uring.prepareWritev(fd_, iovecs, nr_vecs, 0, req);
Expand Down Expand Up @@ -485,6 +525,11 @@ void IoUringSocketHandleImpl::FileEventAdapter::onRequestCompletion(const Reques
ENVOY_LOG(debug, "async request failed: {}", errorDetails(-result));
}

if (req.closed_) {
printf("IoUringSocketHandleImpl::FileEventAdapter::onRequestCompletion, request is closed, ret = %d\n", result);
return;
}

switch (req.type_) {
case RequestType::Accept:
ASSERT(!SOCKET_VALID(connection_fd_));
Expand All @@ -504,6 +549,12 @@ void IoUringSocketHandleImpl::FileEventAdapter::onRequestCompletion(const Reques
break;
}

// put the request into the map.
auto iter = iohandle.request_map_.find(req.id_);
if (iter != iohandle.request_map_.end()) {
iohandle.request_map_.erase(iter);
}

if (result == 0) {
iohandle.remote_closed_ = true;
}
Expand All @@ -513,11 +564,24 @@ void IoUringSocketHandleImpl::FileEventAdapter::onRequestCompletion(const Reques
}
break;
}
case RequestType::Connect:
case RequestType::Connect: {
ASSERT(req.iohandle_.has_value());
printf("IoUringSocketHandleImpl::FileEventAdapter::onRequestCompletion, Connected, fd = %d, ret = %d\n", req.iohandle_->get().fd_, result);
if (req.iohandle_->get().fd_ == -1) {
ENVOY_LOG_MISC(debug, "the uring's fd already close, we got -1 fd in connect request");
break;
}

auto& iohandle = req.iohandle_->get();
auto iter = iohandle.request_map_.find(req.id_);
if (iter != iohandle.request_map_.end()) {
iohandle.request_map_.erase(iter);
}

req.iohandle_->get().cb_(result < 0 ? Event::FileReadyType::Closed
: Event::FileReadyType::Write);
break;
}
case RequestType::Write: {
ASSERT(req.iov_ != nullptr);
ASSERT(req.iohandle_.has_value());
Expand All @@ -529,6 +593,11 @@ void IoUringSocketHandleImpl::FileEventAdapter::onRequestCompletion(const Reques
break;
}

auto iter = iohandle.request_map_.find(req.id_);
if (iter != iohandle.request_map_.end()) {
iohandle.request_map_.erase(iter);
}

if (result < 0) {
delete[] req.iov_;
iohandle.cb_(Event::FileReadyType::Closed);
Expand All @@ -537,8 +606,16 @@ void IoUringSocketHandleImpl::FileEventAdapter::onRequestCompletion(const Reques
}
break;
}
case RequestType::Close:
case RequestType::Close: {
printf("IoUringSocketHandleImpl::FileEventAdapter::onRequestCompletion, Close, ret = %d\n", result);

auto iter = req.iohandle_->get().request_map_.find(req.id_);
if (iter != req.iohandle_->get().request_map_.end()) {
req.iohandle_->get().request_map_.erase(iter);
}

break;
}
default:
PANIC("not implemented");
}
Expand Down
5 changes: 5 additions & 0 deletions source/common/network/io_uring_socket_handle_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ struct Request {
RequestType type_{RequestType::Unknown};
struct iovec* iov_{nullptr};
std::list<Buffer::SliceDataPtr> slices_{};
uint64_t id_{0};
os_fd_t fd_{-1};
bool closed_{false};
};

/**
Expand Down Expand Up @@ -133,6 +136,8 @@ class IoUringSocketHandleImpl final : public IoHandle, protected Logger::Loggabl
bool is_write_added_{false};
std::unique_ptr<FileEventAdapter> file_event_adapter_{nullptr};
bool remote_closed_{false};
uint64_t global_request_id{0};
absl::flat_hash_map<uint64_t, Request*> request_map_;
};

} // namespace Network
Expand Down