From 0aa7a2a0511e7d3d731c77cb551831df84009ec9 Mon Sep 17 00:00:00 2001 From: rami3l Date: Wed, 12 Jun 2024 09:48:29 +0800 Subject: [PATCH] refactor(test): execute all `#[rustup_macros::unit_test]`s within a `tokio` context --- rustup-macros/src/lib.rs | 107 +++++++++++++++------------------------ src/test.rs | 16 ------ 2 files changed, 41 insertions(+), 82 deletions(-) diff --git a/rustup-macros/src/lib.rs b/rustup-macros/src/lib.rs index 5ccb3ae5cf..23818dee78 100644 --- a/rustup-macros/src/lib.rs +++ b/rustup-macros/src/lib.rs @@ -39,11 +39,16 @@ pub fn integration_test( .into() } -/// Custom wrapper macro around `#[test]` and `#[tokio::test]` for unit tests. +/// Custom wrapper macro around `#[tokio::test]` for unit tests. /// /// Calls `rustup::test::before_test()` before the test body, and /// `rustup::test::after_test()` after, even in the event of an unwinding panic. -/// For async functions calls the async variants of these functions. +/// +/// This wrapper makes the underlying test function async even if it's sync in nature. +/// This ensures that a [`tokio`] runtime is always present during tests, +/// making it easier to setup [`tracing`] subscribers +/// (e.g. [`opentelemetry_otlp::OtlpTracePipeline`] always requires a [`tokio`] runtime to be +/// installed). #[proc_macro_attribute] pub fn unit_test( args: proc_macro::TokenStream, @@ -77,74 +82,44 @@ pub fn unit_test( .into() } -// False positive from clippy :/ -#[allow(clippy::redundant_clone)] fn test_inner(mod_path: String, mut input: ItemFn) -> syn::Result { - if input.sig.asyncness.is_some() { - let before_ident = format!("{}::before_test_async", mod_path); - let before_ident = syn::parse_str::(&before_ident)?; - let after_ident = format!("{}::after_test_async", mod_path); - let after_ident = syn::parse_str::(&after_ident)?; - - let inner = input.block; - let name = input.sig.ident.clone(); - let new_block: Block = parse_quote! { - { - #before_ident().await; - // Define a function with same name we can instrument inside the - // tracing enablement logic. - #[cfg_attr(feature = "otel", tracing::instrument(skip_all))] - async fn #name() { #inner } - // Thunk through a new thread to permit catching the panic - // without grabbing the entire state machine defined by the - // outer test function. - let result = ::std::panic::catch_unwind(||{ - let handle = tokio::runtime::Handle::current().clone(); - ::std::thread::spawn(move || handle.block_on(#name())).join().unwrap() - }); - #after_ident().await; - match result { - Ok(result) => result, - Err(err) => ::std::panic::resume_unwind(err) - } - } - }; + // Make the test function async even if it's sync. + input.sig.asyncness.get_or_insert_with(Default::default); - input.block = Box::new(new_block); + let before_ident = format!("{}::before_test_async", mod_path); + let before_ident = syn::parse_str::(&before_ident)?; + let after_ident = format!("{}::after_test_async", mod_path); + let after_ident = syn::parse_str::(&after_ident)?; - Ok(quote! { + let inner = input.block; + let name = input.sig.ident.clone(); + let new_block: Block = parse_quote! { + { + #before_ident().await; + // Define a function with same name we can instrument inside the + // tracing enablement logic. #[cfg_attr(feature = "otel", tracing::instrument(skip_all))] - #[::tokio::test(flavor = "multi_thread", worker_threads = 1)] - #input - }) - } else { - let before_ident = format!("{}::before_test", mod_path); - let before_ident = syn::parse_str::(&before_ident)?; - let after_ident = format!("{}::after_test", mod_path); - let after_ident = syn::parse_str::(&after_ident)?; - - let inner = input.block; - let name = input.sig.ident.clone(); - let new_block: Block = parse_quote! { - { - #before_ident(); - // Define a function with same name we can instrument inside the - // tracing enablement logic. - #[cfg_attr(feature = "otel", tracing::instrument(skip_all))] - fn #name() { #inner } - let result = ::std::panic::catch_unwind(#name); - #after_ident(); - match result { - Ok(result) => result, - Err(err) => ::std::panic::resume_unwind(err) - } + async fn #name() { #inner } + // Thunk through a new thread to permit catching the panic + // without grabbing the entire state machine defined by the + // outer test function. + let result = ::std::panic::catch_unwind(||{ + let handle = tokio::runtime::Handle::current().clone(); + ::std::thread::spawn(move || handle.block_on(#name())).join().unwrap() + }); + #after_ident().await; + match result { + Ok(result) => result, + Err(err) => ::std::panic::resume_unwind(err) } - }; + } + }; - input.block = Box::new(new_block); - Ok(quote! { - #[::std::prelude::v1::test] - #input - }) - } + input.block = Box::new(new_block); + + Ok(quote! { + #[cfg_attr(feature = "otel", tracing::instrument(skip_all))] + #[::tokio::test(flavor = "multi_thread", worker_threads = 1)] + #input + }) } diff --git a/src/test.rs b/src/test.rs index 617b765cea..67f1aedd9f 100644 --- a/src/test.rs +++ b/src/test.rs @@ -277,13 +277,6 @@ static TRACER: Lazy = Lazy::new(|| { tracer }); -pub fn before_test() { - #[cfg(feature = "otel")] - { - Lazy::force(&TRACER); - } -} - pub async fn before_test_async() { #[cfg(feature = "otel")] { @@ -291,15 +284,6 @@ pub async fn before_test_async() { } } -pub fn after_test() { - #[cfg(feature = "otel")] - { - let handle = TRACE_RUNTIME.handle(); - let _guard = handle.enter(); - TRACER.provider().map(|p| p.force_flush()); - } -} - pub async fn after_test_async() { #[cfg(feature = "otel")] {