Skip to content

Commit

Permalink
Parser: auto_node macro
Browse files Browse the repository at this point in the history
  • Loading branch information
191220029 authored and genedna committed Nov 13, 2024
1 parent f4d2466 commit a96fd4a
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 4 deletions.
10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@ repository = "https://github.com/open-rust-initiative/dagrs"
keywords = ["DAG", "task", "async", "parallel", "concurrent"]

[workspace]
members = ["."]
members = [".", "derive"]

[dependencies]
tokio = { version = "1.28", features = ["rt", "sync", "rt-multi-thread"] }
log = "0.4"
env_logger = "0.10.1"
async-trait = "0.1.83"
derive = { path = "derive", optional = true }

[dev-dependencies]
simplelog = "0.12"
criterion = { version = "0.5.1", features = ["html_reports"] }

[target.'cfg(unix)'.dev-dependencies]


[features]
default = ["derive"]
derive = ["derive/derive"]

[[example]]
name = "auto_node"
required-features = ["derive"]
17 changes: 17 additions & 0 deletions derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "derive"
version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"

[dependencies]
syn = { version = "2.0", features = ["full"] }
quote = "1.0"
proc-macro2= "1.0"

[lib]
proc-macro = true

[features]
default = ["derive"]
derive = []
180 changes: 180 additions & 0 deletions derive/src/auto_node.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::Parser;
use syn::{parse, parse_macro_input, Field, Generics, Ident, ItemStruct};

/// Generate fields & implements of `Node` trait.
///
/// Step 1: generate fields (`id`, `name`, `input_channel`, `output_channel`, `action`)
///
/// Step 2: generates methods for `Node` implementation.
///
/// Step 3: append the generated fields to the input struct.
///
/// Step 4: return tokens of the input struct & the generated methods.
pub(crate) fn auto_node(args: TokenStream, input: TokenStream) -> TokenStream {
let mut item_struct = parse_macro_input!(input as ItemStruct);
let _ = parse_macro_input!(args as parse::Nothing);

let generics = &item_struct.generics;

let field_id = syn::Field::parse_named
.parse2(quote! {
id: dagrs::NodeId
})
.unwrap();

let field_name = syn::Field::parse_named
.parse2(quote! {
name: String
})
.unwrap();

let field_in_channels = syn::Field::parse_named
.parse2(quote! {
input_channels: dagrs::InChannels
})
.unwrap();

let field_out_channels = syn::Field::parse_named
.parse2(quote! {
output_channels: dagrs::OutChannels
})
.unwrap();

let field_action = syn::Field::parse_named
.parse2(quote! {
action: Box<dyn dagrs::Action>
})
.unwrap();

let auto_impl = auto_impl_node(
&item_struct.ident,
generics,
&field_id,
&field_name,
&field_in_channels,
&field_out_channels,
&field_action,
);

match item_struct.fields {
syn::Fields::Named(ref mut fields) => {
fields.named.push(field_id);
fields.named.push(field_name);
fields.named.push(field_in_channels);
fields.named.push(field_out_channels);
fields.named.push(field_action);
}
syn::Fields::Unit => {
item_struct.fields = syn::Fields::Named(syn::FieldsNamed {
named: [
field_id,
field_name,
field_in_channels,
field_out_channels,
field_action,
]
.into_iter()
.collect(),
brace_token: Default::default(),
});
}
_ => {
return syn::Error::new_spanned(
item_struct.ident,
"`auto_node` macro can only be annotated on named struct or unit struct.",
)
.into_compile_error()
.into()
}
};

return quote! {
#item_struct
#auto_impl
}
.into();
}

fn auto_impl_node(
struct_ident: &Ident,
generics: &Generics,
field_id: &Field,
field_name: &Field,
field_in_channels: &Field,
field_out_channels: &Field,
field_action: &Field,
) -> proc_macro2::TokenStream {
let mut impl_tokens = proc_macro2::TokenStream::new();
impl_tokens.extend([
impl_id(field_id),
impl_name(field_name),
impl_in_channels(field_in_channels),
impl_out_channels(field_out_channels),
impl_run(field_action, field_in_channels, field_out_channels),
]);

quote::quote!(
impl #generics dagrs::Node for #struct_ident #generics {
#impl_tokens
}
unsafe impl #generics Send for #struct_ident #generics{}
unsafe impl #generics Sync for #struct_ident #generics{}
)
}

fn impl_id(field: &Field) -> proc_macro2::TokenStream {
let ident = &field.ident;
quote::quote!(
fn id(&self) -> dagrs::NodeId {
self.#ident
}
)
}

fn impl_name(field: &Field) -> proc_macro2::TokenStream {
let ident = &field.ident;
quote::quote!(
fn name(&self) -> dagrs::NodeName {
self.#ident.clone()
}
)
}

fn impl_in_channels(field: &Field) -> proc_macro2::TokenStream {
let ident = &field.ident;
quote::quote!(
fn input_channels(&mut self) -> &mut dagrs::InChannels {
&mut self.#ident
}
)
}

fn impl_out_channels(field: &Field) -> proc_macro2::TokenStream {
let ident = &field.ident;
quote::quote!(
fn output_channels(&mut self) -> &mut dagrs::OutChannels {
&mut self.#ident
}
)
}

fn impl_run(
field: &Field,
field_in_channels: &Field,
field_out_channels: &Field,
) -> proc_macro2::TokenStream {
let ident = &field.ident;
let in_channels_ident = &field_in_channels.ident;
let out_channels_ident = &field_out_channels.ident;
quote::quote!(
fn run(&mut self, env: std::sync::Arc<dagrs::EnvVar>) -> dagrs::Output {
tokio::runtime::Runtime::new().unwrap().block_on(async {
self.#ident
.run(&mut self.#in_channels_ident, &self.#out_channels_ident, env)
.await
})
}
)
}
40 changes: 40 additions & 0 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use proc_macro::TokenStream;

#[cfg(feature = "derive")]
mod auto_node;

/// [`auto_node`] is a macro that may be used when customizing nodes. It can only be
/// marked on named struct or unit struct.
///
/// The macro [`auto_node`] generates essential fields and implementation of traits for
/// structs intended to represent `Node` in **Dagrs**.
/// By applying this macro to a struct, it appends fields including `id: dagrs::NodeId`,
/// `name: dagrs::NodeName`, `input_channels: dagrs::InChannels`, `output_channels: dagrs::OutChannels`,
/// and `action: dagrs::Action`, and implements the required `dagrs::Node` trait.
///
/// ## Example
/// - Mark `auto_node` on a struct with customized fields.
/// ```ignore
/// use dagrs::auto_node;
/// #[auto_node]
/// struct MyNode {/*Put your customized fields here.*/}
/// ```
///
/// - Mark `auto_node` on a struct with generic & lifetime params.
/// ```ignore
/// use dagrs::auto_node;
/// #[auto_node]
/// struct MyNode<T, 'a> {/*Put your customized fields here.*/}
/// ```
/// - Mark `auto_node` on a unit struct.
/// ```ignore
/// use dagrs::auto_node;
/// #[auto_node]
/// struct MyNode()
/// ```
#[cfg(feature = "derive")]
#[proc_macro_attribute]
pub fn auto_node(args: TokenStream, input: TokenStream) -> TokenStream {
use crate::auto_node::auto_node;
auto_node(args, input).into()
}
38 changes: 38 additions & 0 deletions examples/auto_node.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::sync::Arc;

use dagrs::{auto_node, EmptyAction, EnvVar, InChannels, Node, NodeTable, OutChannels};

#[auto_node]
struct MyNode {/*Put customized fields here.*/}

#[auto_node]
struct _MyNodeGeneric<T, 'a> {
my_field: Vec<T>,
my_name: &'a str,
}

#[auto_node]
struct _MyUnitNode;

fn main() {
let mut node_table = NodeTable::default();

let node_name = "auto_node".to_string();

let mut s = MyNode {
id: node_table.alloc_id_for(&node_name),
name: node_name.clone(),
input_channels: InChannels::default(),
output_channels: OutChannels::default(),
action: Box::new(EmptyAction),
};

assert_eq!(&s.id(), node_table.get(&node_name).unwrap());
assert_eq!(&s.name(), &node_name);

let output = s.run(Arc::new(EnvVar::new(NodeTable::default())));
match output {
dagrs::Output::Out(content) => assert!(content.is_none()),
_ => panic!(),
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@ pub use node::{
default_node::DefaultNode,
node::*,
};
pub use tokio;
pub use utils::{env::EnvVar, output::Output};

#[cfg(feature = "derive")]
pub use derive::*;
2 changes: 1 addition & 1 deletion src/node/default_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub struct DefaultNode {

impl Node for DefaultNode {
fn id(&self) -> NodeId {
self.id.clone()
self.id
}

fn name(&self) -> NodeName {
Expand Down
2 changes: 1 addition & 1 deletion src/node/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub trait Node: Send + Sync {
fn run(&mut self, env: Arc<EnvVar>) -> Output;
}

#[derive(Debug, Hash, PartialEq, Eq, Clone)]
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct NodeId(pub(crate) usize);

pub type NodeName = String;
Expand Down

0 comments on commit a96fd4a

Please sign in to comment.