Skip to content

Commit

Permalink
refactor(test): execute all #[rustup_macros::unit_test]s within a `…
Browse files Browse the repository at this point in the history
…tokio` context
  • Loading branch information
rami3l committed Jun 14, 2024
1 parent 3ea4355 commit 0aa7a2a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 82 deletions.
107 changes: 41 additions & 66 deletions rustup-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ pub fn integration_test(
.into()
}

/// Custom wrapper macro around `#[test]` and `#[tokio::test]` for unit tests.
/// Custom wrapper macro around `#[tokio::test]` for unit tests.
///
/// Calls `rustup::test::before_test()` before the test body, and
/// `rustup::test::after_test()` after, even in the event of an unwinding panic.
/// For async functions calls the async variants of these functions.
///
/// This wrapper makes the underlying test function async even if it's sync in nature.
/// This ensures that a [`tokio`] runtime is always present during tests,
/// making it easier to setup [`tracing`] subscribers
/// (e.g. [`opentelemetry_otlp::OtlpTracePipeline`] always requires a [`tokio`] runtime to be
/// installed).
#[proc_macro_attribute]
pub fn unit_test(
args: proc_macro::TokenStream,
Expand Down Expand Up @@ -77,74 +82,44 @@ pub fn unit_test(
.into()
}

// False positive from clippy :/
#[allow(clippy::redundant_clone)]
fn test_inner(mod_path: String, mut input: ItemFn) -> syn::Result<TokenStream> {
if input.sig.asyncness.is_some() {
let before_ident = format!("{}::before_test_async", mod_path);
let before_ident = syn::parse_str::<Expr>(&before_ident)?;
let after_ident = format!("{}::after_test_async", mod_path);
let after_ident = syn::parse_str::<Expr>(&after_ident)?;

let inner = input.block;
let name = input.sig.ident.clone();
let new_block: Block = parse_quote! {
{
#before_ident().await;
// Define a function with same name we can instrument inside the
// tracing enablement logic.
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
async fn #name() { #inner }
// Thunk through a new thread to permit catching the panic
// without grabbing the entire state machine defined by the
// outer test function.
let result = ::std::panic::catch_unwind(||{
let handle = tokio::runtime::Handle::current().clone();
::std::thread::spawn(move || handle.block_on(#name())).join().unwrap()
});
#after_ident().await;
match result {
Ok(result) => result,
Err(err) => ::std::panic::resume_unwind(err)
}
}
};
// Make the test function async even if it's sync.
input.sig.asyncness.get_or_insert_with(Default::default);

input.block = Box::new(new_block);
let before_ident = format!("{}::before_test_async", mod_path);
let before_ident = syn::parse_str::<Expr>(&before_ident)?;
let after_ident = format!("{}::after_test_async", mod_path);
let after_ident = syn::parse_str::<Expr>(&after_ident)?;

Ok(quote! {
let inner = input.block;
let name = input.sig.ident.clone();
let new_block: Block = parse_quote! {
{
#before_ident().await;
// Define a function with same name we can instrument inside the
// tracing enablement logic.
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
#[::tokio::test(flavor = "multi_thread", worker_threads = 1)]
#input
})
} else {
let before_ident = format!("{}::before_test", mod_path);
let before_ident = syn::parse_str::<Expr>(&before_ident)?;
let after_ident = format!("{}::after_test", mod_path);
let after_ident = syn::parse_str::<Expr>(&after_ident)?;

let inner = input.block;
let name = input.sig.ident.clone();
let new_block: Block = parse_quote! {
{
#before_ident();
// Define a function with same name we can instrument inside the
// tracing enablement logic.
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
fn #name() { #inner }
let result = ::std::panic::catch_unwind(#name);
#after_ident();
match result {
Ok(result) => result,
Err(err) => ::std::panic::resume_unwind(err)
}
async fn #name() { #inner }
// Thunk through a new thread to permit catching the panic
// without grabbing the entire state machine defined by the
// outer test function.
let result = ::std::panic::catch_unwind(||{
let handle = tokio::runtime::Handle::current().clone();
::std::thread::spawn(move || handle.block_on(#name())).join().unwrap()
});
#after_ident().await;
match result {
Ok(result) => result,
Err(err) => ::std::panic::resume_unwind(err)
}
};
}
};

input.block = Box::new(new_block);
Ok(quote! {
#[::std::prelude::v1::test]
#input
})
}
input.block = Box::new(new_block);

Ok(quote! {
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
#[::tokio::test(flavor = "multi_thread", worker_threads = 1)]
#input
})
}
16 changes: 0 additions & 16 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,29 +277,13 @@ static TRACER: Lazy<opentelemetry_sdk::trace::Tracer> = Lazy::new(|| {
tracer
});

pub fn before_test() {
#[cfg(feature = "otel")]
{
Lazy::force(&TRACER);
}
}

pub async fn before_test_async() {
#[cfg(feature = "otel")]
{
Lazy::force(&TRACER);
}
}

pub fn after_test() {
#[cfg(feature = "otel")]
{
let handle = TRACE_RUNTIME.handle();
let _guard = handle.enter();
TRACER.provider().map(|p| p.force_flush());
}
}

pub async fn after_test_async() {
#[cfg(feature = "otel")]
{
Expand Down

0 comments on commit 0aa7a2a

Please sign in to comment.