From b1a09a3577ea89e856cd1330525956c0a506d842 Mon Sep 17 00:00:00 2001 From: "M.Amin Rayej" Date: Fri, 15 Mar 2024 01:46:39 +0330 Subject: [PATCH 1/2] use Mutex instead of atomic boolean and spin --- tokio/src/io/split.rs | 71 ++++++++++++------------------------------- 1 file changed, 19 insertions(+), 52 deletions(-) diff --git a/tokio/src/io/split.rs b/tokio/src/io/split.rs index 63f0960e4f3..7a432d50b51 100644 --- a/tokio/src/io/split.rs +++ b/tokio/src/io/split.rs @@ -10,9 +10,8 @@ use std::cell::UnsafeCell; use std::fmt; use std::io; use std::pin::Pin; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering::{Acquire, Release}; use std::sync::Arc; +use std::sync::Mutex; use std::task::{Context, Poll}; cfg_io_util! { @@ -38,8 +37,7 @@ cfg_io_util! { let is_write_vectored = stream.is_write_vectored(); let inner = Arc::new(Inner { - locked: AtomicBool::new(false), - stream: UnsafeCell::new(stream), + stream: Mutex::new(UnsafeCell::new(stream)), is_write_vectored, }); @@ -54,13 +52,19 @@ cfg_io_util! { } struct Inner { - locked: AtomicBool, - stream: UnsafeCell, + stream: Mutex>, is_write_vectored: bool, } -struct Guard<'a, T> { - inner: &'a Inner, +impl Inner { + fn with_lock(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R { + let mut guard = self.stream.lock().unwrap(); + + // safety: we do not move the stream. + let stream = unsafe { Pin::new_unchecked(guard.get_mut()) }; + + f(stream) + } } impl ReadHalf { @@ -90,7 +94,7 @@ impl ReadHalf { .ok() .expect("`Arc::try_unwrap` failed"); - inner.stream.into_inner() + inner.stream.into_inner().unwrap().into_inner() } else { panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.") } @@ -111,8 +115,7 @@ impl AsyncRead for ReadHalf { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_read(cx, buf) + self.inner.with_lock(|stream| stream.poll_read(cx, buf)) } } @@ -122,18 +125,15 @@ impl AsyncWrite for WriteHalf { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_write(cx, buf) + self.inner.with_lock(|stream| stream.poll_write(cx, buf)) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_flush(cx) + self.inner.with_lock(|stream| stream.poll_flush(cx)) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_shutdown(cx) + self.inner.with_lock(|stream| stream.poll_shutdown(cx)) } fn poll_write_vectored( @@ -141,8 +141,8 @@ impl AsyncWrite for WriteHalf { cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_write_vectored(cx, bufs) + self.inner + .with_lock(|stream| stream.poll_write_vectored(cx, bufs)) } fn is_write_vectored(&self) -> bool { @@ -150,39 +150,6 @@ impl AsyncWrite for WriteHalf { } } -impl Inner { - fn poll_lock(&self, cx: &mut Context<'_>) -> Poll> { - if self - .locked - .compare_exchange(false, true, Acquire, Acquire) - .is_ok() - { - Poll::Ready(Guard { inner: self }) - } else { - // Spin... but investigate a better strategy - - std::thread::yield_now(); - cx.waker().wake_by_ref(); - - Poll::Pending - } - } -} - -impl Guard<'_, T> { - fn stream_pin(&mut self) -> Pin<&mut T> { - // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual - // exclusion. - unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) } - } -} - -impl Drop for Guard<'_, T> { - fn drop(&mut self) { - self.inner.locked.store(false, Release); - } -} - unsafe impl Send for ReadHalf {} unsafe impl Send for WriteHalf {} unsafe impl Sync for ReadHalf {} From 4434730d854d9c9b81421de051ed47d455b87a31 Mon Sep 17 00:00:00 2001 From: "M.Amin Rayej" Date: Fri, 15 Mar 2024 12:27:35 +0330 Subject: [PATCH 2/2] remove unnecessary use of UnsafeCell --- tokio/src/io/split.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tokio/src/io/split.rs b/tokio/src/io/split.rs index 7a432d50b51..2602929cdd1 100644 --- a/tokio/src/io/split.rs +++ b/tokio/src/io/split.rs @@ -6,7 +6,6 @@ use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; -use std::cell::UnsafeCell; use std::fmt; use std::io; use std::pin::Pin; @@ -37,7 +36,7 @@ cfg_io_util! { let is_write_vectored = stream.is_write_vectored(); let inner = Arc::new(Inner { - stream: Mutex::new(UnsafeCell::new(stream)), + stream: Mutex::new(stream), is_write_vectored, }); @@ -52,7 +51,7 @@ cfg_io_util! { } struct Inner { - stream: Mutex>, + stream: Mutex, is_write_vectored: bool, } @@ -61,7 +60,7 @@ impl Inner { let mut guard = self.stream.lock().unwrap(); // safety: we do not move the stream. - let stream = unsafe { Pin::new_unchecked(guard.get_mut()) }; + let stream = unsafe { Pin::new_unchecked(&mut *guard) }; f(stream) } @@ -94,7 +93,7 @@ impl ReadHalf { .ok() .expect("`Arc::try_unwrap` failed"); - inner.stream.into_inner().unwrap().into_inner() + inner.stream.into_inner().unwrap() } else { panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.") }