From 40da6b9d40235909c23c96c7747a13ba9820e59d Mon Sep 17 00:00:00 2001 From: 191220029 <522023330025@smail.nju.edu.cn> Date: Tue, 12 Nov 2024 10:22:31 +0800 Subject: [PATCH] Parser: `auto_node` macro --- Cargo.toml | 10 ++- derive/Cargo.toml | 17 ++++ derive/src/auto_node.rs | 180 +++++++++++++++++++++++++++++++++++++++ derive/src/lib.rs | 40 +++++++++ examples/auto_node.rs | 38 +++++++++ src/lib.rs | 4 + src/node/default_node.rs | 2 +- src/node/node.rs | 2 +- 8 files changed, 289 insertions(+), 4 deletions(-) create mode 100644 derive/Cargo.toml create mode 100644 derive/src/auto_node.rs create mode 100644 derive/src/lib.rs create mode 100644 examples/auto_node.rs diff --git a/Cargo.toml b/Cargo.toml index 336c341..6b87308 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,13 +10,14 @@ repository = "https://github.com/open-rust-initiative/dagrs" keywords = ["DAG", "task", "async", "parallel", "concurrent"] [workspace] -members = ["."] +members = [".", "derive"] [dependencies] tokio = { version = "1.28", features = ["rt", "sync", "rt-multi-thread"] } log = "0.4" env_logger = "0.10.1" async-trait = "0.1.83" +derive = { path = "derive", optional = true } [dev-dependencies] simplelog = "0.12" @@ -24,5 +25,10 @@ criterion = { version = "0.5.1", features = ["html_reports"] } [target.'cfg(unix)'.dev-dependencies] - [features] +default = ["derive"] +derive = ["derive/derive"] + +[[example]] +name = "auto_node" +required-features = ["derive"] \ No newline at end of file diff --git a/derive/Cargo.toml b/derive/Cargo.toml new file mode 100644 index 0000000..c5feb40 --- /dev/null +++ b/derive/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "derive" +version = "0.1.0" +edition = "2021" +license = "MIT OR Apache-2.0" + +[dependencies] +syn = { version = "2.0", features = ["full"] } +quote = "1.0" +proc-macro2= "1.0" + +[lib] +proc-macro = true + +[features] +default = ["derive"] +derive = [] \ No newline at end of file diff --git a/derive/src/auto_node.rs b/derive/src/auto_node.rs new file mode 100644 index 0000000..91e399e --- /dev/null +++ b/derive/src/auto_node.rs @@ -0,0 +1,180 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::parse::Parser; +use syn::{parse, parse_macro_input, Field, Generics, Ident, ItemStruct}; + +/// Generate fields & implements of `Node` trait. +/// +/// Step 1: generate fields (`id`, `name`, `input_channel`, `output_channel`, `action`) +/// +/// Step 2: generates methods for `Node` implementation. +/// +/// Step 3: append the generated fields to the input struct. +/// +/// Step 4: return tokens of the input struct & the generated methods. +pub(crate) fn auto_node(args: TokenStream, input: TokenStream) -> TokenStream { + let mut item_struct = parse_macro_input!(input as ItemStruct); + let _ = parse_macro_input!(args as parse::Nothing); + + let generics = &item_struct.generics; + + let field_id = syn::Field::parse_named + .parse2(quote! { + id: dagrs::NodeId + }) + .unwrap(); + + let field_name = syn::Field::parse_named + .parse2(quote! { + name: String + }) + .unwrap(); + + let field_in_channels = syn::Field::parse_named + .parse2(quote! { + input_channels: dagrs::InChannels + }) + .unwrap(); + + let field_out_channels = syn::Field::parse_named + .parse2(quote! { + output_channels: dagrs::OutChannels + }) + .unwrap(); + + let field_action = syn::Field::parse_named + .parse2(quote! { + action: Box + }) + .unwrap(); + + let auto_impl = auto_impl_node( + &item_struct.ident, + generics, + &field_id, + &field_name, + &field_in_channels, + &field_out_channels, + &field_action, + ); + + match item_struct.fields { + syn::Fields::Named(ref mut fields) => { + fields.named.push(field_id); + fields.named.push(field_name); + fields.named.push(field_in_channels); + fields.named.push(field_out_channels); + fields.named.push(field_action); + } + syn::Fields::Unit => { + item_struct.fields = syn::Fields::Named(syn::FieldsNamed { + named: [ + field_id, + field_name, + field_in_channels, + field_out_channels, + field_action, + ] + .into_iter() + .collect(), + brace_token: Default::default(), + }); + } + _ => { + return syn::Error::new_spanned( + item_struct.ident, + "`auto_node` macro can only be annotated on named struct or unit struct.", + ) + .into_compile_error() + .into() + } + }; + + return quote! { + #item_struct + #auto_impl + } + .into(); +} + +fn auto_impl_node( + struct_ident: &Ident, + generics: &Generics, + field_id: &Field, + field_name: &Field, + field_in_channels: &Field, + field_out_channels: &Field, + field_action: &Field, +) -> proc_macro2::TokenStream { + let mut impl_tokens = proc_macro2::TokenStream::new(); + impl_tokens.extend([ + impl_id(field_id), + impl_name(field_name), + impl_in_channels(field_in_channels), + impl_out_channels(field_out_channels), + impl_run(field_action, field_in_channels, field_out_channels), + ]); + + quote::quote!( + impl #generics dagrs::Node for #struct_ident #generics { + #impl_tokens + } + unsafe impl #generics Send for #struct_ident #generics{} + unsafe impl #generics Sync for #struct_ident #generics{} + ) +} + +fn impl_id(field: &Field) -> proc_macro2::TokenStream { + let ident = &field.ident; + quote::quote!( + fn id(&self) -> dagrs::NodeId { + self.#ident + } + ) +} + +fn impl_name(field: &Field) -> proc_macro2::TokenStream { + let ident = &field.ident; + quote::quote!( + fn name(&self) -> dagrs::NodeName { + self.#ident.clone() + } + ) +} + +fn impl_in_channels(field: &Field) -> proc_macro2::TokenStream { + let ident = &field.ident; + quote::quote!( + fn input_channels(&mut self) -> &mut dagrs::InChannels { + &mut self.#ident + } + ) +} + +fn impl_out_channels(field: &Field) -> proc_macro2::TokenStream { + let ident = &field.ident; + quote::quote!( + fn output_channels(&mut self) -> &mut dagrs::OutChannels { + &mut self.#ident + } + ) +} + +fn impl_run( + field: &Field, + field_in_channels: &Field, + field_out_channels: &Field, +) -> proc_macro2::TokenStream { + let ident = &field.ident; + let in_channels_ident = &field_in_channels.ident; + let out_channels_ident = &field_out_channels.ident; + quote::quote!( + fn run(&mut self, env: std::sync::Arc) -> dagrs::Output { + tokio::runtime::Runtime::new().unwrap().block_on(async { + self.#ident + .run(&mut self.#in_channels_ident, &self.#out_channels_ident, env) + .await + }) + } + ) +} diff --git a/derive/src/lib.rs b/derive/src/lib.rs new file mode 100644 index 0000000..7f928f6 --- /dev/null +++ b/derive/src/lib.rs @@ -0,0 +1,40 @@ +use proc_macro::TokenStream; + +#[cfg(feature = "derive")] +mod auto_node; + +/// [`auto_node`] is a macro that may be used when customizing nodes. It can only be +/// marked on named struct or unit struct. +/// +/// The macro [`auto_node`] generates essential fields and implementation of traits for +/// structs intended to represent `Node` in **Dagrs**. +/// By applying this macro to a struct, it appends fields including `id: dagrs::NodeId`, +/// `name: dagrs::NodeName`, `input_channels: dagrs::InChannels`, `output_channels: dagrs::OutChannels`, +/// and `action: dagrs::Action`, and implements the required `dagrs::Node` trait. +/// +/// ## Example +/// - Mark `auto_node` on a struct with customized fields. +/// ```ignore +/// use dagrs::auto_node; +/// #[auto_node] +/// struct MyNode {/*Put your customized fields here.*/} +/// ``` +/// +/// - Mark `auto_node` on a struct with generic & lifetime params. +/// ```ignore +/// use dagrs::auto_node; +/// #[auto_node] +/// struct MyNode {/*Put your customized fields here.*/} +/// ``` +/// - Mark `auto_node` on a unit struct. +/// ```ignore +/// use dagrs::auto_node; +/// #[auto_node] +/// struct MyNode() +/// ``` +#[cfg(feature = "derive")] +#[proc_macro_attribute] +pub fn auto_node(args: TokenStream, input: TokenStream) -> TokenStream { + use crate::auto_node::auto_node; + auto_node(args, input).into() +} diff --git a/examples/auto_node.rs b/examples/auto_node.rs new file mode 100644 index 0000000..a6653f4 --- /dev/null +++ b/examples/auto_node.rs @@ -0,0 +1,38 @@ +use std::sync::Arc; + +use dagrs::{auto_node, EmptyAction, EnvVar, InChannels, Node, NodeTable, OutChannels}; + +#[auto_node] +struct MyNode {/*Put customized fields here.*/} + +#[auto_node] +struct _MyNodeGeneric { + my_field: Vec, + my_name: &'a str, +} + +#[auto_node] +struct _MyUnitNode; + +fn main() { + let mut node_table = NodeTable::default(); + + let node_name = "auto_node".to_string(); + + let mut s = MyNode { + id: node_table.alloc_id_for(&node_name), + name: node_name.clone(), + input_channels: InChannels::default(), + output_channels: OutChannels::default(), + action: Box::new(EmptyAction), + }; + + assert_eq!(&s.id(), node_table.get(&node_name).unwrap()); + assert_eq!(&s.name(), &node_name); + + let output = s.run(Arc::new(EnvVar::new(NodeTable::default()))); + match output { + dagrs::Output::Out(content) => assert!(content.is_none()), + _ => panic!(), + } +} diff --git a/src/lib.rs b/src/lib.rs index 4116d78..7e5bb1a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,4 +12,8 @@ pub use node::{ default_node::DefaultNode, node::*, }; +pub use tokio; pub use utils::{env::EnvVar, output::Output}; + +#[cfg(feature = "derive")] +pub use derive::*; diff --git a/src/node/default_node.rs b/src/node/default_node.rs index 6e483fe..929a695 100644 --- a/src/node/default_node.rs +++ b/src/node/default_node.rs @@ -53,7 +53,7 @@ pub struct DefaultNode { impl Node for DefaultNode { fn id(&self) -> NodeId { - self.id.clone() + self.id } fn name(&self) -> NodeName { diff --git a/src/node/node.rs b/src/node/node.rs index 405fb11..9188156 100644 --- a/src/node/node.rs +++ b/src/node/node.rs @@ -29,7 +29,7 @@ pub trait Node: Send + Sync { fn run(&mut self, env: Arc) -> Output; } -#[derive(Debug, Hash, PartialEq, Eq, Clone)] +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] pub struct NodeId(pub(crate) usize); pub type NodeName = String;