Skip to content

Commit

Permalink
handle cfg attr in construct runtime + enforce generics
Browse files Browse the repository at this point in the history
  • Loading branch information
gui1117 committed Oct 24, 2024
1 parent 28acc06 commit eb3061c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// limitations under the License

use crate::construct_runtime::Pallet;
use core::str::FromStr;
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::quote;

Expand All @@ -28,7 +29,8 @@ pub fn expand_outer_task(
let mut from_impls = Vec::new();
let mut task_variants = Vec::new();
let mut variant_names = Vec::new();
let mut task_paths = Vec::new();
let mut task_types = Vec::new();
let mut cfg_attrs = Vec::new();
for decl in pallet_decls {
if decl.find_part("Task").is_none() {
continue
Expand All @@ -37,18 +39,31 @@ pub fn expand_outer_task(
let variant_name = &decl.name;
let path = &decl.path;
let index = decl.index;
let instance = decl.instance.as_ref().map(|instance| quote!(, #path::#instance));
let task_type = quote!(#path::Task<#runtime_name #instance>);

let attr = decl.cfg_pattern.iter().fold(TokenStream2::new(), |acc, pattern| {
let attr = TokenStream2::from_str(&format!("#[cfg({})]", pattern.original()))
.expect("was successfully parsed before; qed");
quote! {
#acc
#attr
}
});

from_impls.push(quote! {
impl From<#path::Task<#runtime_name>> for RuntimeTask {
fn from(hr: #path::Task<#runtime_name>) -> Self {
#attr
impl From<#task_type> for RuntimeTask {
fn from(hr: #task_type) -> Self {
RuntimeTask::#variant_name(hr)
}
}

impl TryInto<#path::Task<#runtime_name>> for RuntimeTask {
#attr
impl TryInto<#task_type> for RuntimeTask {
type Error = ();

fn try_into(self) -> Result<#path::Task<#runtime_name>, Self::Error> {
fn try_into(self) -> Result<#task_type, Self::Error> {
match self {
RuntimeTask::#variant_name(hr) => Ok(hr),
_ => Err(()),
Expand All @@ -58,13 +73,16 @@ pub fn expand_outer_task(
});

task_variants.push(quote! {
#attr
#[codec(index = #index)]
#variant_name(#path::Task<#runtime_name>),
#variant_name(#task_type),
});

variant_names.push(quote!(#variant_name));

task_paths.push(quote!(#path::Task));
task_types.push(task_type);

cfg_attrs.push(attr);
}

let prelude = quote!(#scrate::traits::tasks::__private);
Expand All @@ -91,35 +109,50 @@ pub fn expand_outer_task(

fn is_valid(&self) -> bool {
match self {
#(RuntimeTask::#variant_names(val) => val.is_valid(),)*
#(
#cfg_attrs
RuntimeTask::#variant_names(val) => val.is_valid(),
)*
_ => unreachable!(#INCOMPLETE_MATCH_QED),
}
}

fn run(&self) -> Result<(), #scrate::traits::tasks::__private::DispatchError> {
match self {
#(RuntimeTask::#variant_names(val) => val.run(),)*
#(
#cfg_attrs
RuntimeTask::#variant_names(val) => val.run(),
)*
_ => unreachable!(#INCOMPLETE_MATCH_QED),
}
}

fn weight(&self) -> #scrate::pallet_prelude::Weight {
match self {
#(RuntimeTask::#variant_names(val) => val.weight(),)*
#(
#cfg_attrs
RuntimeTask::#variant_names(val) => val.weight(),
)*
_ => unreachable!(#INCOMPLETE_MATCH_QED),
}
}

fn task_index(&self) -> u32 {
match self {
#(RuntimeTask::#variant_names(val) => val.task_index(),)*
#(
#cfg_attrs
RuntimeTask::#variant_names(val) => val.task_index(),
)*
_ => unreachable!(#INCOMPLETE_MATCH_QED),
}
}

fn iter() -> Self::Enumeration {
let mut all_tasks = Vec::new();
#(all_tasks.extend(#task_paths::iter().map(RuntimeTask::from).collect::<Vec<_>>());)*
#(
#cfg_attrs
all_tasks.extend(<#task_types>::iter().map(RuntimeTask::from).collect::<Vec<_>>());
)*
all_tasks.into_iter()
}
}
Expand Down
3 changes: 3 additions & 0 deletions substrate/frame/support/procedural/src/pallet/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,9 @@ impl Def {
if let Some(extra_constants) = &self.extra_constants {
instances.extend_from_slice(&extra_constants.instances[..]);
}
if let Some(task_enum) = &self.task_enum {
instances.push(task_enum.instance_usage.clone());
}

let mut errors = instances.into_iter().filter_map(|instances| {
if instances.has_instance == self.config.has_instance {
Expand Down
14 changes: 9 additions & 5 deletions substrate/frame/support/procedural/src/pallet/parse/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::assert_parse_error_matches;
#[cfg(test)]
use crate::pallet::parse::tests::simulate_manifest_dir;

use super::helper;
use derive_syn_parse::Parse;
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
Expand Down Expand Up @@ -139,10 +140,11 @@ pub type PalletTaskEnumAttr = PalletTaskAttr<keywords::task_enum>;

/// Parsing for a manually-specified (or auto-generated) task enum, optionally including the
/// attached `#[pallet::task_enum]` attribute.
#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct TaskEnumDef {
pub attr: Option<PalletTaskEnumAttr>,
pub item_enum: ItemEnum,
pub instance_usage: helper::InstanceUsage,
}

impl syn::parse::Parse for TaskEnumDef {
Expand All @@ -154,7 +156,9 @@ impl syn::parse::Parse for TaskEnumDef {
None => None,
};

Ok(TaskEnumDef { attr, item_enum })
let instance_usage = helper::check_type_def_gen(&item_enum.generics, item_enum.span())?;

Ok(TaskEnumDef { attr, item_enum, instance_usage })
}
}

Expand Down Expand Up @@ -881,7 +885,7 @@ fn test_parse_task_enum_def_non_task_name() {
simulate_manifest_dir("../../examples/basic", || {
parse2::<TaskEnumDef>(quote! {
#[pallet::task_enum]
pub enum Something {
pub enum Something<T> {
Foo
}
})
Expand All @@ -906,7 +910,7 @@ fn test_parse_task_enum_def_missing_attr_allowed() {
fn test_parse_task_enum_def_missing_attr_alternate_name_allowed() {
simulate_manifest_dir("../../examples/basic", || {
parse2::<TaskEnumDef>(quote! {
pub enum Foo {
pub enum Foo<T> {
Red,
}
})
Expand Down Expand Up @@ -936,7 +940,7 @@ fn test_parse_task_enum_def_wrong_item() {
assert_parse_error_matches!(
parse2::<TaskEnumDef>(quote! {
#[pallet::task_enum]
pub struct Something;
pub struct Something<T>;
}),
"expected `enum`"
);
Expand Down
5 changes: 5 additions & 0 deletions substrate/frame/support/test/tests/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ mod runtime {

#[runtime::pallet_index(1)]
pub type MyPallet = my_pallet;

#[runtime::pallet_index(2)]
pub type MyPallet2 = my_pallet<Instance2>;
}

// NOTE: Needed for derive_impl expansion
Expand All @@ -82,6 +85,8 @@ impl frame_system::Config for Runtime {

impl my_pallet::Config for Runtime {}

impl my_pallet::Config<frame_support::instances::Instance2> for Runtime {}

fn new_test_ext() -> sp_io::TestExternalities {
use sp_runtime::BuildStorage;

Expand Down

0 comments on commit eb3061c

Please sign in to comment.