diff --git a/near-plugins-derive/src/access_controllable.rs b/near-plugins-derive/src/access_controllable.rs index 2088ebb..8b1e9c0 100644 --- a/near-plugins-derive/src/access_controllable.rs +++ b/near-plugins-derive/src/access_controllable.rs @@ -1,4 +1,5 @@ use crate::access_control_role::new_bitflags_type_ident; +use crate::utils; use crate::utils::{cratename, is_near_bindgen_wrapped_or_marshall}; use darling::FromMeta; use proc_macro::TokenStream; @@ -535,14 +536,7 @@ pub fn access_control_any(attrs: TokenStream, item: TokenStream) -> TokenStream return item; } - let ItemFn { - attrs, - vis, - sig, - block, - } = input; - let function_name = sig.ident.to_string(); - let stmts = &block.stmts; + let function_name = input.sig.ident.to_string(); let macro_args = match MacroArgsAny::from_list(&attr_args) { Ok(args) => args, @@ -569,12 +563,5 @@ pub fn access_control_any(attrs: TokenStream, item: TokenStream) -> TokenStream } }; - // https://stackoverflow.com/a/66851407 - quote! { - #(#attrs)* #vis #sig { - #acl_check - #(#stmts)* - } - } - .into() + utils::add_extra_code_to_fn(&input, acl_check) } diff --git a/near-plugins-derive/src/ownable.rs b/near-plugins-derive/src/ownable.rs index a5e01af..f8dc5a9 100644 --- a/near-plugins-derive/src/ownable.rs +++ b/near-plugins-derive/src/ownable.rs @@ -1,3 +1,4 @@ +use crate::utils; use crate::utils::{cratename, is_near_bindgen_wrapped_or_marshall}; use darling::FromDeriveInput; use proc_macro::{self, TokenStream}; @@ -98,18 +99,9 @@ pub fn only(attrs: TokenStream, item: TokenStream) -> TokenStream { } } - let ItemFn { - attrs, - vis, - sig, - block, - } = input; - let stmts = &block.stmts; - let owner_check = match (contains_self, contains_owner) { (true, true) => quote! { - let __predecessor_account_id = ::near_sdk::env::predecessor_account_id(); - if self.owner_get() != Some(__predecessor_account_id) { + if !self.owner_is() { ::near_sdk::assert_self(); } }, @@ -124,12 +116,5 @@ pub fn only(attrs: TokenStream, item: TokenStream) -> TokenStream { } }; - // https://stackoverflow.com/a/66851407 - quote! { - #(#attrs)* #vis #sig { - #owner_check - #(#stmts)* - } - } - .into() + utils::add_extra_code_to_fn(&input, owner_check) } diff --git a/near-plugins-derive/src/pausable.rs b/near-plugins-derive/src/pausable.rs index 5ad5dc4..060e005 100644 --- a/near-plugins-derive/src/pausable.rs +++ b/near-plugins-derive/src/pausable.rs @@ -1,3 +1,4 @@ +use crate::utils; use crate::utils::{cratename, is_near_bindgen_wrapped_or_marshall}; use darling::{FromDeriveInput, FromMeta}; use proc_macro::{self, TokenStream}; @@ -120,40 +121,9 @@ pub fn pause(attrs: TokenStream, item: TokenStream) -> TokenStream { let attr_args = parse_macro_input!(attrs as AttributeArgs); let args = PauseArgs::from_list(&attr_args).expect("Invalid arguments"); - let ItemFn { - attrs, - vis, - sig, - block, - } = input; - let stmts = &block.stmts; + let fn_name = args.name.unwrap_or_else(|| input.sig.ident.to_string()); - let fn_name = args.name.unwrap_or_else(|| sig.ident.to_string()); - - let self_condition = if args.except._self { - quote!( - if ::near_sdk::env::predecessor_account_id() == ::near_sdk::env::current_account_id() { - check_paused = false; - } - ) - } else { - quote!() - }; - - let owner_condition = if args.except.owner { - quote!( - if Some(::near_sdk::env::predecessor_account_id()) == self.owner_get() { - check_paused = false; - } - ) - } else { - quote!() - }; - - let bypass_condition = quote!( - #self_condition - #owner_condition - ); + let bypass_condition = get_bypass_condition(&args.except); let check_pause = quote!( let mut check_paused = true; @@ -163,15 +133,7 @@ pub fn pause(attrs: TokenStream, item: TokenStream) -> TokenStream { } ); - // https://stackoverflow.com/a/66851407 - let result = quote! { - #(#attrs)* #vis #sig { - #check_pause - #(#stmts)* - } - }; - - result.into() + utils::add_extra_code_to_fn(&input, check_pause) } #[derive(Debug, FromMeta)] @@ -191,17 +153,23 @@ pub fn if_paused(attrs: TokenStream, item: TokenStream) -> TokenStream { let attr_args = parse_macro_input!(attrs as AttributeArgs); let args = IfPausedArgs::from_list(&attr_args).expect("Invalid arguments"); - let ItemFn { - attrs, - vis, - sig, - block, - } = input; - let stmts = &block.stmts; - let fn_name = args.name; - let self_condition = if args.except._self { + let bypass_condition = get_bypass_condition(&args.except); + + let check_pause = quote!( + let mut check_paused = true; + #bypass_condition + if check_paused { + assert!(self.pa_is_paused(#fn_name.to_string()), "Pausable: Method must be paused"); + } + ); + + utils::add_extra_code_to_fn(&input, check_pause) +} + +fn get_bypass_condition(args: &ExceptSubArgs) -> proc_macro2::TokenStream { + let self_condition = if args._self { quote!( if ::near_sdk::env::predecessor_account_id() == ::near_sdk::env::current_account_id() { check_paused = false; @@ -211,7 +179,7 @@ pub fn if_paused(attrs: TokenStream, item: TokenStream) -> TokenStream { quote!() }; - let owner_condition = if args.except.owner { + let owner_condition = if args.owner { quote!( if Some(::near_sdk::env::predecessor_account_id()) == self.owner_get() { check_paused = false; @@ -221,26 +189,8 @@ pub fn if_paused(attrs: TokenStream, item: TokenStream) -> TokenStream { quote!() }; - let bypass_condition = quote!( + quote!( #self_condition #owner_condition - ); - - let check_pause = quote!( - let mut check_paused = true; - #bypass_condition - if check_paused { - assert!(self.pa_is_paused(#fn_name.to_string()), "Pausable: Method must be paused"); - } - ); - - // https://stackoverflow.com/a/66851407 - let result = quote! { - #(#attrs)* #vis #sig { - #check_pause - #(#stmts)* - } - }; - - result.into() + ) } diff --git a/near-plugins-derive/src/utils.rs b/near-plugins-derive/src/utils.rs index ce0b7f8..5a24f78 100644 --- a/near-plugins-derive/src/utils.rs +++ b/near-plugins-derive/src/utils.rs @@ -92,3 +92,25 @@ pub(crate) fn cratename() -> Ident { Span::call_site(), ) } + +pub(crate) fn add_extra_code_to_fn( + fn_code: &ItemFn, + extra_code: proc_macro2::TokenStream, +) -> proc_macro::TokenStream { + let ItemFn { + attrs, + vis, + sig, + block, + } = fn_code; + let stmts = &block.stmts; + + // https://stackoverflow.com/a/66851407 + quote::quote! { + #(#attrs)* #vis #sig { + #extra_code + #(#stmts)* + } + } + .into() +}