Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep track of initialized bytes in read_to_end #3426

Merged
merged 6 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tokio/src/io/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ cfg_io_util! {
mod read_line;

mod read_to_end;
mod vec_with_initialized;
cfg_process! {
pub(crate) use read_to_end::read_to_end;
}
Expand Down Expand Up @@ -82,6 +83,7 @@ cfg_io_util! {

cfg_not_io_util! {
cfg_process! {
mod vec_with_initialized;
mod read_to_end;
// Used by process
pub(crate) use read_to_end::read_to_end;
Expand Down
68 changes: 33 additions & 35 deletions tokio/src/io/util/read_to_end.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::io::{AsyncRead, ReadBuf};
use crate::io::util::vec_with_initialized::{into_read_buf_parts, VecWithInitialized};
use crate::io::AsyncRead;

use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::marker::PhantomPinned;
use std::mem::{self, MaybeUninit};
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll};

Expand All @@ -13,7 +14,7 @@ pin_project! {
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadToEnd<'a, R: ?Sized> {
reader: &'a mut R,
buf: &'a mut Vec<u8>,
buf: VecWithInitialized<&'a mut Vec<u8>>,
// The number of bytes appended to buf. This can be less than buf.len() if
// the buffer was not empty when the operation was started.
read: usize,
Expand All @@ -27,22 +28,22 @@ pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buffer: &'a mut Vec<u8>) ->
where
R: AsyncRead + Unpin + ?Sized,
{
// SAFETY: The generic type on VecWithInitialized is &mut Vec<u8>.
ReadToEnd {
reader,
buf: buffer,
buf: unsafe { VecWithInitialized::new(buffer) },
read: 0,
_pin: PhantomPinned,
}
}

pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>(
buf: &mut Vec<u8>,
pub(super) fn read_to_end_internal<V: AsMut<Vec<u8>>, R: AsyncRead + ?Sized>(
buf: &mut VecWithInitialized<V>,
mut reader: Pin<&mut R>,
num_read: &mut usize,
cx: &mut Context<'_>,
) -> Poll<io::Result<usize>> {
loop {
// safety: The caller promised to prepare the buffer.
let ret = ready!(poll_read_to_end(buf, reader.as_mut(), cx));
match ret {
Err(err) => return Poll::Ready(Err(err)),
Expand All @@ -57,8 +58,8 @@ pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>(
/// Tries to read from the provided AsyncRead.
///
/// The length of the buffer is increased by the number of bytes read.
fn poll_read_to_end<R: AsyncRead + ?Sized>(
buf: &mut Vec<u8>,
fn poll_read_to_end<V: AsMut<Vec<u8>>, R: AsyncRead + ?Sized>(
buf: &mut VecWithInitialized<V>,
read: Pin<&mut R>,
cx: &mut Context<'_>,
) -> Poll<io::Result<usize>> {
Expand All @@ -68,37 +69,34 @@ fn poll_read_to_end<R: AsyncRead + ?Sized>(
// of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every
// time is 4,500 times (!) slower than this if the reader has a very small
// amount of data to return.
reserve(buf, 32);
buf.reserve(32);

let mut unused_capacity = ReadBuf::uninit(get_unused_capacity(buf));
// Get a ReadBuf into the vector.
let mut read_buf = buf.get_read_buf();

let ptr = unused_capacity.filled().as_ptr();
ready!(read.poll_read(cx, &mut unused_capacity))?;
assert_eq!(ptr, unused_capacity.filled().as_ptr());
let filled_before = read_buf.filled().len();
let poll_result = read.poll_read(cx, &mut read_buf);
let filled_after = read_buf.filled().len();
let n = filled_after - filled_before;

let n = unused_capacity.filled().len();
let new_len = buf.len() + n;
// Update the length of the vector using the result of poll_read.
let read_buf_parts = into_read_buf_parts(read_buf);
buf.apply_read_buf(read_buf_parts);

assert!(new_len <= buf.capacity());
unsafe {
buf.set_len(new_len);
}
Poll::Ready(Ok(n))
}

/// Allocates more memory and ensures that the unused capacity is prepared for use
/// with the `AsyncRead`.
fn reserve(buf: &mut Vec<u8>, bytes: usize) {
if buf.capacity() - buf.len() >= bytes {
return;
match poll_result {
Poll::Pending => {
// In this case, nothing should have been read. However we still
// update the vector in case the poll_read call initialized parts of
// the vector's unused capacity.
debug_assert_eq!(filled_before, filled_after);
Poll::Pending
}
Poll::Ready(Err(err)) => {
debug_assert_eq!(filled_before, filled_after);
Poll::Ready(Err(err))
}
Poll::Ready(Ok(())) => Poll::Ready(Ok(n)),
}
buf.reserve(bytes);
}

/// Returns the unused capacity of the provided vector.
fn get_unused_capacity(buf: &mut Vec<u8>) -> &mut [MaybeUninit<u8>] {
let uninit = bytes::BufMut::chunk_mut(buf);
unsafe { &mut *(uninit as *mut _ as *mut [MaybeUninit<u8>]) }
}

impl<A> Future for ReadToEnd<'_, A>
Expand Down
20 changes: 8 additions & 12 deletions tokio/src/io/util/read_to_string.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::io::util::read_line::finish_string_read;
use crate::io::util::read_to_end::read_to_end_internal;
use crate::io::util::vec_with_initialized::VecWithInitialized;
use crate::io::AsyncRead;

use pin_project_lite::pin_project;
Expand All @@ -19,7 +20,7 @@ pin_project! {
// while reading to postpone utf-8 handling until after reading.
output: &'a mut String,
// The actual allocation of the string is moved into this vector instead.
buf: Vec<u8>,
buf: VecWithInitialized<Vec<u8>>,
// The number of bytes appended to buf. This can be less than buf.len() if
// the buffer was not empty when the operation was started.
read: usize,
Expand All @@ -37,29 +38,25 @@ where
R: AsyncRead + ?Sized + Unpin,
{
let buf = mem::replace(string, String::new()).into_bytes();
// SAFETY: The generic type of the VecWithInitialized is Vec<u8>.
ReadToString {
reader,
buf,
buf: unsafe { VecWithInitialized::new(buf) },
output: string,
read: 0,
_pin: PhantomPinned,
}
}

/// # Safety
///
/// Before first calling this method, the unused capacity must have been
/// prepared for use with the provided AsyncRead. This can be done using the
/// `prepare_buffer` function in `read_to_end.rs`.
unsafe fn read_to_string_internal<R: AsyncRead + ?Sized>(
fn read_to_string_internal<R: AsyncRead + ?Sized>(
reader: Pin<&mut R>,
output: &mut String,
buf: &mut Vec<u8>,
buf: &mut VecWithInitialized<Vec<u8>>,
read: &mut usize,
cx: &mut Context<'_>,
) -> Poll<io::Result<usize>> {
let io_res = ready!(read_to_end_internal(buf, reader, read, cx));
let utf8_res = String::from_utf8(mem::replace(buf, Vec::new()));
let utf8_res = String::from_utf8(buf.take());

// At this point both buf and output are empty. The allocation is in utf8_res.

Expand All @@ -77,7 +74,6 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let me = self.project();

// safety: The constructor of ReadToString called `prepare_buffer`.
unsafe { read_to_string_internal(Pin::new(*me.reader), me.output, me.buf, me.read, cx) }
read_to_string_internal(Pin::new(*me.reader), me.output, me.buf, me.read, cx)
}
}
120 changes: 120 additions & 0 deletions tokio/src/io/util/vec_with_initialized.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use crate::io::ReadBuf;
use std::mem::MaybeUninit;

/// This struct wraps a `Vec<u8>` or `&mut Vec<u8>`, combining it with a
/// `num_initialized`, which keeps track of the number of initialized bytes
/// in the unused capacity.
///
/// The purpose of this struct is to remember how many bytes were initialized
/// through a `ReadBuf` from call to call.
///
/// This struct has the safety invariant that the first `num_initialized` of the
/// vector's allocation must be initialized at any time.
#[derive(Debug)]
pub(crate) struct VecWithInitialized<V> {
vec: V,
// The number of initialized bytes in the vector.
// Always between `vec.len()` and `vec.capacity()`.
num_initialized: usize,
}

impl VecWithInitialized<Vec<u8>> {
#[cfg(feature = "io-util")]
pub(crate) fn take(&mut self) -> Vec<u8> {
self.num_initialized = 0;
std::mem::take(&mut self.vec)
}
}

impl<V> VecWithInitialized<V>
where
V: AsMut<Vec<u8>>,
{
/// Safety: The generic parameter `V` must be either `Vec<u8>` or `&mut Vec<u8>`.
pub(crate) unsafe fn new(mut vec: V) -> Self {
// SAFETY: The safety invariants of vector guarantee that the bytes up
// to its length are initialized.
Self {
num_initialized: vec.as_mut().len(),
vec,
}
}

pub(crate) fn reserve(&mut self, num_bytes: usize) {
let vec = self.vec.as_mut();
if vec.capacity() - vec.len() >= num_bytes {
return;
}
// SAFETY: Setting num_initialized to `vec.len()` is correct as
// `reserve` does not change the length of the vector.
self.num_initialized = vec.len();
vec.reserve(num_bytes);
}

#[cfg(feature = "io-util")]
pub(crate) fn is_empty(&mut self) -> bool {
self.vec.as_mut().is_empty()
}

pub(crate) fn get_read_buf<'a>(&'a mut self) -> ReadBuf<'a> {
let num_initialized = self.num_initialized;

// SAFETY: Creating the slice is safe because of the safety invariants
// on Vec<u8>. The safety invariants of `ReadBuf` will further guarantee
// that no bytes in the slice are de-initialized.
let vec = self.vec.as_mut();
let len = vec.len();
let cap = vec.capacity();
let ptr = vec.as_mut_ptr().cast::<MaybeUninit<u8>>();
let slice = unsafe { std::slice::from_raw_parts_mut::<'a, MaybeUninit<u8>>(ptr, cap) };

// SAFETY: This is safe because the safety invariants of
// VecWithInitialized say that the first num_initialized bytes must be
// initialized.
let mut read_buf = ReadBuf::uninit(slice);
unsafe {
read_buf.assume_init(num_initialized);
}
read_buf.set_filled(len);

read_buf
}

pub(crate) fn apply_read_buf(&mut self, parts: ReadBufParts) {
let vec = self.vec.as_mut();
assert_eq!(vec.as_ptr(), parts.ptr);

// SAFETY:
// The ReadBufParts really does point inside `self.vec` due to the above
// check, and the safety invariants of `ReadBuf` guarantee that the
// first `parts.initialized` bytes of `self.vec` really have been
// initialized. Additionally, `ReadBuf` guarantees that `parts.len` is
// at most `parts.initialized`, so the first `parts.len` bytes are also
// initialized.
//
// Note that this relies on the fact that `V` is either `Vec<u8>` or
// `&mut Vec<u8>`, so the vector returned by `self.vec.as_mut()` cannot
// change from call to call.
unsafe {
self.num_initialized = parts.initialized;
vec.set_len(parts.len);
}
}
}

pub(crate) struct ReadBufParts {
// Pointer is only used to check that the ReadBuf actually came from the
// right VecWithInitialized.
ptr: *const u8,
len: usize,
initialized: usize,
}

// This is needed to release the borrow on `VecWithInitialized<V>`.
pub(crate) fn into_read_buf_parts(rb: ReadBuf<'_>) -> ReadBufParts {
ReadBufParts {
ptr: rb.filled().as_ptr(),
len: rb.filled().len(),
initialized: rb.initialized().len(),
}
}
65 changes: 64 additions & 1 deletion tokio/tests/io_read_to_end.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::AsyncReadExt;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
use tokio_test::assert_ok;

#[tokio::test]
Expand All @@ -13,3 +15,64 @@ async fn read_to_end() {
assert_eq!(n, 11);
assert_eq!(buf[..], b"hello world"[..]);
}

#[derive(Copy, Clone, Debug)]
enum State {
Initializing,
JustFilling,
Done,
}

struct UninitTest {
num_init: usize,
state: State,
}

impl AsyncRead for UninitTest {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let me = Pin::into_inner(self);
let real_num_init = buf.initialized().len() - buf.filled().len();
assert_eq!(real_num_init, me.num_init, "{:?}", me.state);

match me.state {
State::Initializing => {
buf.initialize_unfilled_to(me.num_init + 2);
buf.advance(1);
me.num_init += 1;

if me.num_init == 24 {
me.state = State::JustFilling;
}
}
State::JustFilling => {
buf.advance(1);
me.num_init -= 1;

if me.num_init == 15 {
// The buffer is resized on next call.
me.num_init = 0;
me.state = State::Done;
}
}
State::Done => { /* .. do nothing .. */ }
}

Poll::Ready(Ok(()))
}
}

#[tokio::test]
async fn read_to_end_uninit() {
let mut buf = Vec::with_capacity(64);
let mut test = UninitTest {
num_init: 0,
state: State::Initializing,
};

test.read_to_end(&mut buf).await.unwrap();
assert_eq!(buf.len(), 33);
}