diff --git a/tokio/src/macros/join.rs b/tokio/src/macros/join.rs index cd89da1cc02..8c6ad545d09 100644 --- a/tokio/src/macros/join.rs +++ b/tokio/src/macros/join.rs @@ -113,7 +113,7 @@ doc! {macro_rules! join { (@ { // Type of rotator that controls which inner future to start with // when polling our output future. - rotator=$rotator:ty; + rotator_select=$rotator_select:ty; // One `_` for each branch in the `join!` macro. This is not used once // normalization is complete. @@ -126,7 +126,7 @@ doc! {macro_rules! join { $( ( $($skip:tt)* ) $e:expr, )* }) => {{ - use $crate::macros::support::{maybe_done, poll_fn, Future, Pin}; + use $crate::macros::support::{maybe_done, poll_fn, Future, Pin, RotatorSelect}; use $crate::macros::support::Poll::{Ready, Pending}; // Safety: nothing must be moved out of `futures`. This is to satisfy @@ -143,14 +143,14 @@ doc! {macro_rules! join { // let mut futures = &mut futures; - const COUNT: u32 = $($total)*; - // Each time the future created by poll_fn is polled, if not using biased mode, // a different future is polled first to ensure every future passed to join! // can make progress even if one of the futures consumes the whole budget. - let mut rotator = <$rotator>::default(); + let mut rotator = <$rotator_select as RotatorSelect>::Rotator::<{$($total)*}>::default(); poll_fn(move |cx| { + const COUNT: u32 = $($total)*; + let mut is_pending = false; let mut to_run = COUNT; @@ -205,17 +205,17 @@ doc! {macro_rules! join { // ===== Normalize ===== - (@ { rotator=$rotator:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { - $crate::join!(@{ rotator=$rotator; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*) + (@ { rotator_select=$rotator_select:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { + $crate::join!(@{ rotator_select=$rotator_select; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*) }; // ===== Entry point ===== ( biased; $($e:expr),+ $(,)?) => { - $crate::join!(@{ rotator=$crate::macros::support::BiasedRotator; () (0) } $($e,)*) + $crate::join!(@{ rotator_select=$crate::macros::support::SelectBiased; () (0) } $($e,)*) }; ( $($e:expr),+ $(,)?) => { - $crate::join!(@{ rotator=$crate::macros::support::Rotator; () (0) } $($e,)*) + $crate::join!(@{ rotator_select=$crate::macros::support::SelectNormal; () (0) } $($e,)*) }; (biased;) => { async {}.await }; @@ -223,6 +223,30 @@ doc! {macro_rules! join { () => { async {}.await } }} +/// Helper trait to select which type of `Rotator` to use. +// We need this to allow specifying a const generic without +// colliding with caller const names due to macro hygiene. +pub trait RotatorSelect { + type Rotator: Default; +} + +/// Marker type indicating that the starting branch should +/// rotate each poll. +#[derive(Debug)] +pub struct SelectNormal; +/// Marker type indicating that the starting branch should +/// be the first declared branch each poll. +#[derive(Debug)] +pub struct SelectBiased; + +impl RotatorSelect for SelectNormal { + type Rotator = Rotator; +} + +impl RotatorSelect for SelectBiased { + type Rotator = BiasedRotator; +} + /// Rotates by one each [`Self::num_skip`] call up to COUNT - 1. #[derive(Default, Debug)] pub struct Rotator { diff --git a/tokio/src/macros/support.rs b/tokio/src/macros/support.rs index 213d85cde4e..0231354e18f 100644 --- a/tokio/src/macros/support.rs +++ b/tokio/src/macros/support.rs @@ -3,7 +3,7 @@ cfg_macros! { pub use std::future::poll_fn; - pub use crate::macros::join::{BiasedRotator, Rotator}; + pub use crate::macros::join::{BiasedRotator, Rotator, RotatorSelect, SelectNormal, SelectBiased}; #[doc(hidden)] pub fn thread_rng_n(n: u32) -> u32 { diff --git a/tokio/src/macros/try_join.rs b/tokio/src/macros/try_join.rs index 471b5ebb42d..c31854295f1 100644 --- a/tokio/src/macros/try_join.rs +++ b/tokio/src/macros/try_join.rs @@ -166,7 +166,7 @@ doc! {macro_rules! try_join { (@ { // Type of rotator that controls which inner future to start with // when polling our output future. - rotator=$rotator:ty; + rotator_select=$rotator_select:ty; // One `_` for each branch in the `try_join!` macro. This is not used once // normalization is complete. @@ -179,7 +179,7 @@ doc! {macro_rules! try_join { $( ( $($skip:tt)* ) $e:expr, )* }) => {{ - use $crate::macros::support::{maybe_done, poll_fn, Future, Pin}; + use $crate::macros::support::{maybe_done, poll_fn, Future, Pin, RotatorSelect}; use $crate::macros::support::Poll::{Ready, Pending}; // Safety: nothing must be moved out of `futures`. This is to satisfy @@ -196,14 +196,14 @@ doc! {macro_rules! try_join { // let mut futures = &mut futures; - const COUNT: u32 = $($total)*; - // Each time the future created by poll_fn is polled, if not using biased mode, // a different future is polled first to ensure every future passed to try_join! // can make progress even if one of the futures consumes the whole budget. - let mut rotator = <$rotator>::default(); + let mut rotator = <$rotator_select as RotatorSelect>::Rotator::<{$($total)*}>::default(); poll_fn(move |cx| { + const COUNT: u32 = $($total)*; + let mut is_pending = false; let mut to_run = COUNT; @@ -264,17 +264,17 @@ doc! {macro_rules! try_join { // ===== Normalize ===== - (@ { rotator=$rotator:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { - $crate::try_join!(@{ rotator=$rotator; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*) + (@ { rotator_select=$rotator_select:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { + $crate::try_join!(@{ rotator_select=$rotator_select; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*) }; // ===== Entry point ===== ( biased; $($e:expr),+ $(,)?) => { - $crate::try_join!(@{ rotator=$crate::macros::support::BiasedRotator; () (0) } $($e,)*) + $crate::try_join!(@{ rotator_select=$crate::macros::support::SelectBiased; () (0) } $($e,)*) }; ( $($e:expr),+ $(,)?) => { - $crate::try_join!(@{ rotator=$crate::macros::support::Rotator; () (0) } $($e,)*) + $crate::try_join!(@{ rotator_select=$crate::macros::support::SelectNormal; () (0) } $($e,)*) }; (biased;) => { async { Ok(()) }.await }; diff --git a/tokio/tests/macros_join.rs b/tokio/tests/macros_join.rs index fd4fdae3a6a..4c6db26d8ae 100644 --- a/tokio/tests/macros_join.rs +++ b/tokio/tests/macros_join.rs @@ -234,3 +234,23 @@ async fn join_into_future() { tokio::join!(NotAFuture); } + +// Regression test for: https://github.com/tokio-rs/tokio/issues/7637 +// We want to make sure that the `const COUNT: u32` declaration +// inside the macro body doesn't leak to the caller to cause compiler failures +// or variable shadowing. +#[tokio::test] +async fn caller_names_const_count() { + let (tx, rx) = oneshot::channel::(); + + const COUNT: u32 = 2; + + let mut join = task::spawn(async { tokio::join!(async { tx.send(COUNT).unwrap() }) }); + assert_ready!(join.poll()); + + let res = rx.await.unwrap(); + + // This passing demonstrates that the const in the macro is + // not shadowing the caller-specified COUNT value + assert_eq!(2, res); +} diff --git a/tokio/tests/macros_try_join.rs b/tokio/tests/macros_try_join.rs index e68c3400ff3..03172ca2a2d 100644 --- a/tokio/tests/macros_try_join.rs +++ b/tokio/tests/macros_try_join.rs @@ -247,3 +247,23 @@ async fn empty_try_join() { assert_eq!(tokio::try_join!() as Result<_, ()>, Ok(())); assert_eq!(tokio::try_join!(biased;) as Result<_, ()>, Ok(())); } + +// Regression test for: https://github.com/tokio-rs/tokio/issues/7637 +// We want to make sure that the `const COUNT: u32` declaration +// inside the macro body doesn't leak to the caller to cause compiler failures +// or variable shadowing. +#[tokio::test] +async fn caller_names_const_count() { + let (tx, rx) = oneshot::channel::(); + + const COUNT: u32 = 2; + + let mut try_join = task::spawn(async { tokio::try_join!(async { tx.send(COUNT) }) }); + assert_ready!(try_join.poll()).unwrap(); + + let res = rx.await.unwrap(); + + // This passing demonstrates that the const in the macro is + // not shadowing the caller-specified COUNT value + assert_eq!(2, res); +}