From ab53bf0c4727ae63c6a3a3d772b7bd837d2f49c3 Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Sat, 3 Aug 2024 11:32:50 +0100 Subject: [PATCH] runtime: prevent niche-optimization to avoid triggering miri (#6744) --- tokio/src/future/maybe_done.rs | 52 ++++++++++++++++++ tokio/src/runtime/task/core.rs | 1 + tokio/src/runtime/tests/task.rs | 94 +++++++++++++++++++++++++++++++++ tokio/tests/macros_join.rs | 2 +- 4 files changed, 148 insertions(+), 1 deletion(-) diff --git a/tokio/src/future/maybe_done.rs b/tokio/src/future/maybe_done.rs index 8b270b3a01f..9ae795f7a7f 100644 --- a/tokio/src/future/maybe_done.rs +++ b/tokio/src/future/maybe_done.rs @@ -10,6 +10,7 @@ pin_project! { #[derive(Debug)] #[project = MaybeDoneProj] #[project_replace = MaybeDoneProjReplace] + #[repr(C)] // https://github.com/rust-lang/miri/issues/3780 pub enum MaybeDone { /// A not-yet-completed future. Future { #[pin] future: Fut }, @@ -69,3 +70,54 @@ impl Future for MaybeDone { Poll::Ready(()) } } + +// Test for https://github.com/tokio-rs/tokio/issues/6729 +#[cfg(test)] +mod miri_tests { + use super::maybe_done; + + use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll, Wake}, + }; + + struct ThingAdder<'a> { + thing: &'a mut String, + } + + impl Future for ThingAdder<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + unsafe { + *self.get_unchecked_mut().thing += ", world"; + } + Poll::Pending + } + } + + #[test] + fn maybe_done_miri() { + let mut thing = "hello".to_owned(); + + // The async block is necessary to trigger the miri failure. + #[allow(clippy::redundant_async_block)] + let fut = async move { ThingAdder { thing: &mut thing }.await }; + + let mut fut = maybe_done(fut); + let mut fut = unsafe { Pin::new_unchecked(&mut fut) }; + + let waker = Arc::new(DummyWaker).into(); + let mut ctx = Context::from_waker(&waker); + assert_eq!(fut.as_mut().poll(&mut ctx), Poll::Pending); + assert_eq!(fut.as_mut().poll(&mut ctx), Poll::Pending); + } + + struct DummyWaker; + + impl Wake for DummyWaker { + fn wake(self: Arc) {} + } +} diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index 108b06bc8b6..78977084adb 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -196,6 +196,7 @@ generate_addr_of_methods! { } /// Either the future or the output. +#[repr(C)] // https://github.com/rust-lang/miri/issues/3780 pub(super) enum Stage { Running(T), Finished(super::Result), diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs index fc1e4089070..310c69e8b2a 100644 --- a/tokio/src/runtime/tests/task.rs +++ b/tokio/src/runtime/tests/task.rs @@ -223,6 +223,100 @@ fn shutdown_immediately() { }) } +// Test for https://github.com/tokio-rs/tokio/issues/6729 +#[test] +fn spawn_niche_in_task() { + use crate::future::poll_fn; + use std::task::{Context, Poll, Waker}; + + with(|rt| { + let state = Arc::new(Mutex::new(State::new())); + + let mut subscriber = Subscriber::new(Arc::clone(&state), 1); + rt.spawn(async move { + subscriber.wait().await; + subscriber.wait().await; + }); + + rt.spawn(async move { + state.lock().unwrap().set_version(2); + state.lock().unwrap().set_version(0); + }); + + rt.tick_max(10); + assert!(rt.is_empty()); + rt.shutdown(); + }); + + pub(crate) struct Subscriber { + state: Arc>, + observed_version: u64, + waker_key: Option, + } + + impl Subscriber { + pub(crate) fn new(state: Arc>, version: u64) -> Self { + Self { + state, + observed_version: version, + waker_key: None, + } + } + + pub(crate) async fn wait(&mut self) { + poll_fn(|cx| { + self.state + .lock() + .unwrap() + .poll_update(&mut self.observed_version, &mut self.waker_key, cx) + .map(|_| ()) + }) + .await; + } + } + + struct State { + version: u64, + wakers: Vec, + } + + impl State { + pub(crate) fn new() -> Self { + Self { + version: 1, + wakers: Vec::new(), + } + } + + pub(crate) fn poll_update( + &mut self, + observed_version: &mut u64, + waker_key: &mut Option, + cx: &Context<'_>, + ) -> Poll> { + if self.version == 0 { + *waker_key = None; + Poll::Ready(None) + } else if *observed_version < self.version { + *waker_key = None; + *observed_version = self.version; + Poll::Ready(Some(())) + } else { + self.wakers.push(cx.waker().clone()); + *waker_key = Some(self.wakers.len()); + Poll::Pending + } + } + + pub(crate) fn set_version(&mut self, version: u64) { + self.version = version; + for waker in self.wakers.drain(..) { + waker.wake(); + } + } + } +} + #[test] fn spawn_during_shutdown() { static DID_SPAWN: AtomicBool = AtomicBool::new(false); diff --git a/tokio/tests/macros_join.rs b/tokio/tests/macros_join.rs index 4deaf120a95..083eecf2976 100644 --- a/tokio/tests/macros_join.rs +++ b/tokio/tests/macros_join.rs @@ -81,7 +81,7 @@ fn join_size() { let ready2 = future::ready(0i32); tokio::join!(ready1, ready2) }; - assert_eq!(mem::size_of_val(&fut), 40); + assert_eq!(mem::size_of_val(&fut), 48); } async fn non_cooperative_task(permits: Arc) -> usize {