diff --git a/faux_macros/src/lib.rs b/faux_macros/src/lib.rs index 5787558..3c32eb5 100644 --- a/faux_macros/src/lib.rs +++ b/faux_macros/src/lib.rs @@ -58,6 +58,7 @@ pub fn when(input: proc_macro::TokenStream) -> proc_macro::TokenStream { receiver, method, args, + turbofish, .. }) => { let when = quote::format_ident!("_when_{}", method); @@ -69,7 +70,8 @@ pub fn when(input: proc_macro::TokenStream) -> proc_macro::TokenStream { match args { Err(e) => e.write_errors().into(), - Ok(args) => TokenStream::from(quote!({ #receiver.#when().with_args((#(#args,)*)) })) + Ok(args) if args.is_empty() => { TokenStream::from(quote!({ #receiver.#when #turbofish() }))} + Ok(args) => { TokenStream::from(quote!({ #receiver.#when #turbofish().with_args((#(#args,)*)) }))} } } expr => darling::Error::custom("faux::when! only accepts arguments in the format of: `when!(receiver.method)` or `receiver.method(args...)`") diff --git a/faux_macros/src/methods/morphed.rs b/faux_macros/src/methods/morphed.rs index 0f3dcd6..d1d04f5 100644 --- a/faux_macros/src/methods/morphed.rs +++ b/faux_macros/src/methods/morphed.rs @@ -1,7 +1,7 @@ use crate::{methods::receiver::Receiver, self_type::SelfType}; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -use syn::{spanned::Spanned, PathArguments, Type, TypePath}; +use syn::{spanned::Spanned, Generics, PathArguments, Type, TypePath}; pub struct Signature<'a> { name: &'a syn::Ident, @@ -147,9 +147,16 @@ impl<'a> Signature<'a> { let name = &self.name; let args = &self.args; + let generics = self + .method_data + .as_ref() + .map(|method_data| method_data.generics.clone()); + + let maybe_generics = generic_types_only(generics); + let proxy = match self.trait_path { - None => quote! { <#real_ty>::#name }, - Some(path) => quote! { <#real_ty as #path>::#name }, + None => quote! { <#real_ty>::#name #maybe_generics }, + Some(path) => quote! { <#real_ty as #path>::#name #maybe_generics }, }; let real_self_arg = self.method_data.as_ref().map(|_| { @@ -210,7 +217,7 @@ impl<'a> Signature<'a> { quote! { unsafe { - match _maybe_faux_faux.call_stub(::#faux_ident, #fn_name, #args) { + match _maybe_faux_faux.call_stub(::#faux_ident #maybe_generics, #fn_name, #args) { std::result::Result::Ok(o) => o, std::result::Result::Err(e) => panic!("{}", e), } @@ -359,11 +366,13 @@ impl<'a> MethodData<'a> { let generics_where_clause = &generics.where_clause; + let maybe_generics = generic_types_only(Some(generics.clone())); + let when_method = syn::parse_quote! { pub fn #when_ident<'m #maybe_comma #generics_contents>(&'m mut self) -> faux::When<'m, #receiver_ty, (#(#arg_types),*), #output, faux::matcher::AnyInvocation> #generics_where_clause { match &mut self.0 { faux::MaybeFaux::Faux(_maybe_faux_faux) => faux::When::new( - ::#faux_ident, + ::#faux_ident #maybe_generics, #name_str, _maybe_faux_faux ), @@ -480,3 +489,15 @@ fn path_args_contains_self(path: &syn::Path, self_path: &syn::TypePath) -> bool } } } + +fn generic_types_only(generics: Option) -> TokenStream { + if let Some(mut g) = generics { + let type_params = g + .type_params_mut() + .into_iter() + .map(|type_param| type_param.ident.clone()); + quote! { :: < #(#type_params),* > } + } else { + quote! {} + } +} diff --git a/tests/generic_method_return.rs b/tests/generic_method_return.rs index d7e80f3..ffbf563 100644 --- a/tests/generic_method_return.rs +++ b/tests/generic_method_return.rs @@ -12,8 +12,6 @@ pub struct Foo {} #[faux::create] pub struct Bar {} - - #[faux::methods] impl Foo { pub fn foo(&self, _e: E) -> E { @@ -22,12 +20,20 @@ impl Foo { pub fn bar(&self, _e: E, _f: F) -> Result { todo!() } - pub fn baz(&self, _e: E) -> E where E: MyTrait { + pub fn baz(&self, _e: E) -> E + where + E: MyTrait, + { + todo!() + } + pub fn qux(&self) + where + E: MyTrait, + { todo!() } } - #[faux::create] struct AsyncFoo {} #[faux::methods] @@ -38,11 +44,26 @@ impl AsyncFoo { pub async fn bar(&self, _e: E, _f: F) -> Result { todo!() } - pub async fn baz(&self, _e: E) -> E where E: MyTrait { + pub async fn baz(&self, _e: E) -> E + where + E: MyTrait, + { + todo!() + } + pub async fn qux(&self) + where + E: MyTrait, + { todo!() } -} + pub async fn qux_with_arg(&self, _arg: u32) -> u32 + where + E: MyTrait, + { + todo!() + } +} #[test] fn generics() { @@ -57,6 +78,10 @@ fn generics() { let mut baz = Foo::faux(); faux::when!(baz.baz).then_return(Entity {}); assert_eq!(baz.baz(Entity {}), Entity {}); + + let mut qux = Foo::faux(); + faux::when!(qux.qux::()).then(|_| {}); + qux.qux::(); } #[test] @@ -69,9 +94,21 @@ fn generic_tests_async() { let mut baz = AsyncFoo::faux(); faux::when!(baz.baz).then_return(Entity {}); + + let mut qux = AsyncFoo::faux(); + faux::when!(qux.qux::()).then(|_| {}); + + let mut qux_with_arg = AsyncFoo::faux(); + faux::when!(qux_with_arg.qux_with_arg::()).then(|_| 100); + faux::when!(qux_with_arg.qux_with_arg::(42)).then(|_| 84); + faux::when!(qux_with_arg.qux_with_arg::(43)).then(|_| 86); futures::executor::block_on(async { assert_eq!(foo.foo(Entity {}).await, Entity {}); assert_eq!(bar.bar(Entity {}, Entity {}).await, Ok(Entity {})); assert_eq!(baz.baz(Entity {}).await, Entity {}); + qux.qux::().await; + assert_eq!(qux_with_arg.qux_with_arg::(42).await, 84); + assert_eq!(qux_with_arg.qux_with_arg::(43).await, 86); + assert_eq!(qux_with_arg.qux_with_arg::(50).await, 100); }); }