diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 0fbfc0e9af..bc2a31a482 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -4,9 +4,9 @@ use crate::{ attr_parsing::{parse_assignment_attribute, second}, with_position::{Position, WithPosition}, }; -use proc_macro2::{Span, TokenStream}; +use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; -use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type}; +use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, ReturnType, Token, Type}; pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { let Attrs { state_ty } = attr; @@ -15,6 +15,7 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { let check_extractor_count = check_extractor_count(&item_fn); let check_path_extractor = check_path_extractor(&item_fn); + let check_output_tuples = check_output_tuples(&item_fn); let check_output_impls_into_response = check_output_impls_into_response(&item_fn); // If the function is generic, we can't reliably check its inputs or whether the future it @@ -72,6 +73,7 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { #item_fn #check_extractor_count #check_path_extractor + #check_output_tuples #check_output_impls_into_response #check_inputs_and_future_send } @@ -284,6 +286,142 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, state_ty: Type) -> TokenStr .collect::() } +///If the output is a tuple with 2 or more elements, +/// it checks with the following pattern, +/// first element => StatusCode || Parts || IntoResponseParts +///last element => IntoResponse +///other elements => IntoResponseParts +///the max numbers of IntoResponseParts(16) +fn check_output_tuples(item_fn: &ItemFn) -> Option { + //Extract tuple types + let elements = match &item_fn.sig.output { + ReturnType::Type(_, ty) => match &**ty { + Type::Tuple(tuple) => &tuple.elems, + _ => return None, + }, + _ => return None, + }; + + if elements.len() < 2 { + return None; + } + //Amount of IntoRequestParts + let mut parts_amount = 0; + + let token_stream = WithPosition::new(elements.iter()) + .enumerate() + .map(|(_idx, arg)| match &arg { + Position::First(ty) => { + let typename = extract_clean_typename(ty); + if typename.is_none() { + quote! {} + } else { + let typename = typename.unwrap(); + match &*typename.to_string() { + "Parts" => quote! {}, + "Response" => quote! {}, + "StatusCode" => { + quote! {} + } + _ => { + parts_amount += 1; + check_into_response_parts(ty) + } + } + } + } + Position::Last(ty) => check_into_response(ty), + Position::Middle(ty) => { + parts_amount += 1; + if parts_amount >= 16 { + let error_message = format!("Output Tuple cannot have more than 16 arguments."); + let error = syn::Error::new_spanned(&item_fn.sig.output, error_message) + .to_compile_error(); + error + } else { + //todo check Named IntoResponse like Json, and hint that it should be placed last. + check_into_response_parts(ty) + } + } + _ => quote! {}, + }) + .collect::(); + Some(token_stream) +} + +fn check_into_response(ty: &Type) -> TokenStream { + let (span, ty) = (ty.span(), ty.clone()); + + let check_fn = format_ident!("__axum_macros_check_into_response_check", span = span,); + + let call_check_fn = format_ident!("__axum_macros_check_into_response_call_check", span = span,); + + let call_check_fn_body = quote_spanned! {span=> + #check_fn(); + }; + + let from_request_bound = quote_spanned! {span=> + #ty: ::axum::response::IntoResponse + }; + quote::quote_spanned! {span=> + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + fn #check_fn() + where + #from_request_bound, + {} + + // we have to call the function to actually trigger a compile error + // since the function is generic, just defining it is not enough + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + fn #call_check_fn() + { + #call_check_fn_body + } + } +} + +fn check_into_response_parts(ty: &Type) -> TokenStream { + let (span, ty) = (ty.span(), ty.clone()); + + let check_fn = format_ident!("__axum_macros_check_into_response_parts_check", span = span,); + + let call_check_fn = format_ident!( + "__axum_macros_check_into_response_parts_call_check", + span = span, + ); + + let call_check_fn_body = quote_spanned! {span=> + #check_fn(); + }; + + let from_request_bound = quote_spanned! {span=> + #ty: ::axum::response::IntoResponseParts + }; + quote::quote_spanned! {span=> + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + fn #check_fn() + where + #from_request_bound, + {} + + // we have to call the function to actually trigger a compile error + // since the function is generic, just defining it is not enough + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + fn #call_check_fn() + { + #call_check_fn_body + } + } +} + fn check_input_order(item_fn: &ItemFn) -> Option { let types_that_consume_the_request = item_fn .sig @@ -355,14 +493,17 @@ fn check_input_order(item_fn: &ItemFn) -> Option { } } -fn request_consuming_type_name(ty: &Type) -> Option<&'static str> { +fn extract_clean_typename(ty: &Type) -> Option<&Ident> { let path = match ty { Type::Path(type_path) => &type_path.path, _ => return None, }; + path.segments.last().map(|p| &p.ident) +} - let ident = match path.segments.last() { - Some(path_segment) => &path_segment.ident, +fn request_consuming_type_name(ty: &Type) -> Option<&'static str> { + let ident = match extract_clean_typename(ty) { + Some(ident) => ident, None => return None, }; diff --git a/axum-macros/tests/debug_handler/fail/wrong_return_tuple.rs b/axum-macros/tests/debug_handler/fail/wrong_return_tuple.rs new file mode 100644 index 0000000000..fa5197a81d --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/wrong_return_tuple.rs @@ -0,0 +1,11 @@ + +#[axum::debug_handler] +async fn handler() -> ( + axum::http::StatusCode, + axum::Json<&'static str>, + axum::response::AppendHeaders<[( axum::http::HeaderName,&'static str); 1]>, +) { + panic!() +} + +fn main(){} \ No newline at end of file diff --git a/axum-macros/tests/debug_handler/fail/wrong_return_tuple.stderr b/axum-macros/tests/debug_handler/fail/wrong_return_tuple.stderr new file mode 100644 index 0000000000..8c75c43d53 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/wrong_return_tuple.stderr @@ -0,0 +1,18 @@ +error[E0277]: the trait bound `Json<&'static str>: IntoResponseParts` is not satisfied + --> tests/debug_handler/fail/wrong_return_tuple.rs:5:5 + | +5 | axum::Json<&'static str>, + | ^^^^^^^^^^^^^^^^^^^^^^^^ the trait `IntoResponseParts` is not implemented for `Json<&'static str>` + | + = help: the following other types implement trait `IntoResponseParts`: + (T1, T2) + (T1, T2, T3) + (T1, T2, T3, T4) + (T1, T2, T3, T4, T5) + (T1, T2, T3, T4, T5, T6) + (T1, T2, T3, T4, T5, T6, T7) + (T1, T2, T3, T4, T5, T6, T7, T8) + (T1, T2, T3, T4, T5, T6, T7, T8, T9) + and $N others + = help: see issue #48214 + = help: add `#![feature(trivial_bounds)]` to the crate attributes to enable