Skip to content

Commit

Permalink
auto_graph
Browse files Browse the repository at this point in the history
sign

Signed-off-by: A-Mavericks <[email protected]>

fmt

Signed-off-by: A-Mavericks <[email protected]>
  • Loading branch information
A-Mavericks authored and genedna committed Dec 19, 2024
1 parent 85685d9 commit 889b6de
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 16 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ log = "0.4"
env_logger = "0.10.1"
async-trait = "0.1.83"
derive = { path = "derive", optional = true }
proc-macro2 = "1.0"

[dev-dependencies]
simplelog = "0.12"
Expand All @@ -31,4 +32,4 @@ derive = ["derive/derive"]

[[example]]
name = "auto_node"
required-features = ["derive"]
required-features = ["derive"]
14 changes: 13 additions & 1 deletion derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use proc_macro::TokenStream;

#[cfg(feature = "derive")]
mod auto_node;
mod relay;

/// [`auto_node`] is a macro that may be used when customizing nodes. It can only be
/// marked on named struct or unit struct.
Expand Down Expand Up @@ -38,3 +38,15 @@ pub fn auto_node(args: TokenStream, input: TokenStream) -> TokenStream {
use crate::auto_node::auto_node;
auto_node(args, input).into()
}

/// The [`dependencies!`] macro lets users declare all task dependencies in an
/// easy-to-understand `task -> successor …` notation. It expands to an
/// expression that builds and returns the graph described by those
/// dependencies.
#[cfg(feature = "derive")]
#[proc_macro]
pub fn dependencies(input: TokenStream) -> TokenStream {
    use relay::{add_relay, Relaies};

    // Parse the `task -> successors` list, then generate the graph-building code.
    let relaies = syn::parse_macro_input!(input as Relaies);
    add_relay(relaies).into()
}
106 changes: 106 additions & 0 deletions derive/src/relay.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use std::collections::{HashMap, HashSet};

Check warning on line 1 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Check

unused import: `HashMap`

Check warning on line 1 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unused import: `HashMap`

use proc_macro2::Ident;
use syn::{parse::Parse, Token};

/// Parses and processes a set of relay tasks and their successors, and generates a directed graph.
///
/// Step 1: Define the `Relay` struct with a task and its associated successors (other tasks that depend on it).
///
/// Step 2: Implement the `Parse` trait for `Relaies` to parse a sequence of task-successor pairs from input. This creates a vector of `Relay` objects.
///
/// Step 3: In `add_relay`, initialize a directed graph structure using `Graph` and a hash map to store edges between nodes.
///
/// Step 4: Iterate through each `Relay` and update the graph's edge list by adding nodes (tasks) and defining edges between tasks and their successors.
///
/// Step 5: Ensure that each task is only added once to the graph using a cache (`HashSet`) to avoid duplicates.
///
/// Step 6: Populate the edges of the graph with the previously processed data and return the graph.
///
/// This code provides the logic to dynamically build a graph based on parsed task relationships, where each task is a node and the successors define directed edges between nodes.
/// One dependency declaration: a task and the tasks that depend on it.
pub(crate) struct Relay {
    /// The task on the left-hand side of the `->` arrow.
    pub(crate) task: Ident,
    /// Tasks listed after the arrow, i.e. the downstream tasks of `task`.
    pub(crate) successors: Vec<Ident>,
}

/// The full sequence of [`Relay`] declarations parsed from the macro input.
pub(crate) struct Relaies(pub(crate) Vec<Relay>);

impl Parse for Relaies {
    /// Parses a comma-separated sequence of `task -> successor …` pairs.
    ///
    /// Each entry names a task, the `->` arrow, and zero or more successor
    /// identifiers; a successor list ends at a comma or at the end of input.
    /// At least one entry is required — an empty input is a parse error.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut parsed = Vec::new();
        loop {
            // Left-hand side: the task identifier, followed by the arrow.
            let task = input.parse::<Ident>()?;
            input.parse::<syn::Token!(->)>()?;

            // Right-hand side: successors up to the next comma or end of input.
            let mut successors = Vec::new();
            while !input.peek(Token!(,)) && !input.is_empty() {
                successors.push(input.parse::<Ident>()?);
            }
            parsed.push(Relay { task, successors });

            // Consume an optional separating/trailing comma, then stop once
            // the input is exhausted.
            let _ = input.parse::<Token!(,)>();
            if input.is_empty() {
                break;
            }
        }
        Ok(Self(parsed))
    }
}

pub(crate) fn add_relay(relaies: Relaies) -> proc_macro2::TokenStream {
let mut token = proc_macro2::TokenStream::new();
let mut cache: HashSet<Ident> = HashSet::new();
token.extend(quote::quote!(
use dagrs::Graph;
use dagrs::NodeId;
use std::collections::HashMap;
use std::collections::HashSet;
let mut edge: HashMap<NodeId, HashSet<NodeId>> = HashMap::new();
let mut graph = Graph::new();
));
for relay in relaies.0.iter() {
let task = relay.task.clone();
token.extend(quote::quote!(
let task_id = #task.id();
if(!edge.contains_key(&task_id)){
edge.insert(task_id, HashSet::new());
}
));
for successor in relay.successors.iter() {
token.extend(quote::quote!(
let successor_id = #successor.id();
edge.entry(task_id)
.or_insert_with(HashSet::new)
.insert(successor_id);
));
}
}
for relay in relaies.0.iter() {
let task = relay.task.clone();
if (!cache.contains(&task)) {

Check warning on line 80 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Check

unnecessary parentheses around `if` condition

Check warning on line 80 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unnecessary parentheses around `if` condition
token.extend(quote::quote!(
graph.add_node(Box::new(#task));
));
cache.insert(task);
}
for successor in relay.successors.iter() {
if (!cache.contains(successor)) {

Check warning on line 87 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Check

unnecessary parentheses around `if` condition

Check warning on line 87 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unnecessary parentheses around `if` condition
token.extend(quote::quote!(
graph.add_node(Box::new(#successor));
));
cache.insert(successor.clone());
}
}
}
token.extend(quote::quote!(for (key, value) in &edge {
let vec = value.iter().cloned().collect();
graph.add_edge(key.clone(), vec);
}));

quote::quote!(
{
#token;
graph
}
)
}
45 changes: 45 additions & 0 deletions examples/auto_relay.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use std::sync::Arc;

use dagrs::{
auto_node, dependencies,
graph::{self, graph::Graph},

Check warning on line 5 in examples/auto_relay.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unused imports: `EnvVar`, `graph::Graph`, and `self`
EmptyAction, EnvVar, InChannels, Node, NodeTable, OutChannels,
};

/// A custom node type. The `#[auto_node]` attribute generates the boilerplate
/// needed to act as a node — presumably the `id`, `name`, `input_channels`,
/// `output_channels`, and `action` fields used by the struct literals in
/// `main` below (confirm against the macro's expansion).
#[auto_node]
struct MyNode {/*Put customized fields here.*/}

fn main() {
    let mut node_table = NodeTable::default();
    let node_name = "auto_node".to_string();

    // All three nodes are built identically, so share the construction logic.
    // Each call allocates a fresh id from the node table.
    let mut new_node = || MyNode {
        id: node_table.alloc_id_for(&node_name),
        name: node_name.clone(),
        input_channels: InChannels::default(),
        output_channels: OutChannels::default(),
        action: Box::new(EmptyAction),
    };

    let s = new_node();
    let a = new_node();
    let b = new_node();

    // Declare the dependencies: `s` precedes `a` and `b`; `b` precedes `a`.
    let mut g = dependencies!(s -> a b,
        b -> a
    );

    g.run();
}
124 changes: 110 additions & 14 deletions src/graph/graph.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::hash::Hash;
use std::sync::mpsc::channel;

Check warning on line 2 in src/graph/graph.rs

View workflow job for this annotation

GitHub Actions / Check

unused import: `std::sync::mpsc::channel`

Check warning on line 2 in src/graph/graph.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unused import: `std::sync::mpsc::channel`
use std::{
collections::HashMap,
collections::{HashMap, HashSet},
panic::{self, AssertUnwindSafe},
sync::{atomic::AtomicBool, Arc},
};
Expand All @@ -8,6 +10,7 @@ use crate::{
connection::{in_channel::InChannel, information_packet::Content, out_channel::OutChannel},
node::node::{Node, NodeId, NodeTable},
utils::{env::EnvVar, execstate::ExecState},
Output,
};

use log::{debug, error};
Expand Down Expand Up @@ -46,6 +49,8 @@ pub struct Graph {
/// Mark whether the net task can continue to execute.
/// When an error occurs during the execution of any task, This flag will still be set to true
is_active: Arc<AtomicBool>,
/// Node's in_degree, used for check loop
in_degree: HashMap<NodeId, usize>,
}

impl Graph {
Expand All @@ -57,6 +62,7 @@ impl Graph {
execute_states: HashMap::new(),
env: Arc::new(EnvVar::new(NodeTable::default())),
is_active: Arc::new(AtomicBool::new(true)),
in_degree: HashMap::new(),
}
}

Expand All @@ -70,22 +76,29 @@ impl Graph {
/// Adds a new node to the `Graph`
pub fn add_node(&mut self, node: Box<dyn Node>) {
self.node_count = self.node_count + 1;
self.nodes.insert(node.id(), node);
let id = node.id();
self.nodes.insert(id, node);
self.in_degree.insert(id, 0);
}
/// Adds an edge between two nodes in the `Graph`.
/// If the outgoing port of the sending node is empty and the number of receiving nodes is > 1, use the broadcast channel
/// An MPSC channel is used if the outgoing port of the sending node is empty and the number of receiving nodes is equal to 1
/// If the outgoing port of the sending node is not empty, adding any number of receiving nodes will change all relevant channels to broadcast
pub fn add_edge(&mut self, from_id: NodeId, to_ids: Vec<NodeId>) {
pub fn add_edge(&mut self, from_id: NodeId, all_to_ids: Vec<NodeId>) {
let from_node = self.nodes.get_mut(&from_id).unwrap();
let from_channel = from_node.output_channels();
let to_ids = Self::remove_duplicates(all_to_ids);
if from_channel.0.is_empty() {
if to_ids.len() > 1 {
let (bcst_sender, _) = broadcast::channel::<Content>(32);
{
for to_id in &to_ids {
from_channel
.insert(*to_id, Arc::new(OutChannel::Bcst(bcst_sender.clone())));
self.in_degree
.entry(*to_id)
.and_modify(|e| *e += 1)
.or_insert(0);
}
}
for to_id in &to_ids {
Expand All @@ -99,27 +112,42 @@ impl Graph {
let (tx, rx) = mpsc::channel::<Content>(32);
{
from_channel.insert(*to_id, Arc::new(OutChannel::Mpsc(tx.clone())));
self.in_degree
.entry(*to_id)
.and_modify(|e| *e += 1)
.or_insert(0);
}
if let Some(to_node) = self.nodes.get_mut(to_id) {
let to_channel = to_node.input_channels();
to_channel.insert(from_id, Arc::new(Mutex::new(InChannel::Mpsc(rx))));
}
}
} else {
let (bcst_sender, _) = broadcast::channel::<Content>(32);
if to_ids.len() > 1
|| (to_ids.len() == 1 && !from_channel.0.contains_key(to_ids.get(0).unwrap()))
{
for _channel in from_channel.0.values_mut() {
*_channel = Arc::new(OutChannel::Bcst(bcst_sender.clone()));
let (bcst_sender, _) = broadcast::channel::<Content>(32);
{
for _channel in from_channel.0.values_mut() {
*_channel = Arc::new(OutChannel::Bcst(bcst_sender.clone()));
}
for to_id in &to_ids {
if !from_channel.0.contains_key(to_id) {
self.in_degree
.entry(*to_id)
.and_modify(|e| *e += 1)
.or_insert(0);
}
from_channel
.insert(*to_id, Arc::new(OutChannel::Bcst(bcst_sender.clone())));
}
}
for to_id in &to_ids {
from_channel.insert(*to_id, Arc::new(OutChannel::Bcst(bcst_sender.clone())));
}
}
for to_id in &to_ids {
if let Some(to_node) = self.nodes.get_mut(to_id) {
let to_channel = to_node.input_channels();
let receiver = bcst_sender.subscribe();
to_channel.insert(from_id, Arc::new(Mutex::new(InChannel::Bcst(receiver))));
if let Some(to_node) = self.nodes.get_mut(to_id) {
let to_channel = to_node.input_channels();
let receiver = bcst_sender.subscribe();
to_channel.insert(from_id, Arc::new(Mutex::new(InChannel::Bcst(receiver))));
}
}
}
}
Expand All @@ -136,6 +164,10 @@ impl Graph {
/// This function is used for the execution of a single net.
pub fn run(&mut self) {
self.init();
let is_loop = self.check_loop();
if is_loop {
panic!("Graph contains a loop.");
}
if !self.is_active.load(std::sync::atomic::Ordering::Relaxed) {
eprintln!("Graph is not active. Aborting execution.");
return;
Expand Down Expand Up @@ -179,6 +211,70 @@ impl Graph {
self.is_active
.store(false, std::sync::atomic::Ordering::Relaxed);
}

///See if the graph has loop
pub fn check_loop(&mut self) -> bool {
let mut queue: Vec<NodeId> = self
.in_degree
.iter()
.filter_map(|(&node_id, &degree)| if degree == 0 { Some(node_id) } else { None })
.collect();

let mut in_degree = self.in_degree.clone();
let mut processed_count = 0;

while let Some(node_id) = queue.pop() {
processed_count += 1;
let node = self.nodes.get_mut(&node_id).unwrap();
let out = node.output_channels();
for (id, channel) in out.0.iter() {

Check warning on line 230 in src/graph/graph.rs

View workflow job for this annotation

GitHub Actions / Check

unused variable: `channel`

Check warning on line 230 in src/graph/graph.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unused variable: `channel`
if let Some(degree) = in_degree.get_mut(id) {
*degree -= 1;
if *degree == 0 {
queue.push(id.clone());
}
}
}
}
processed_count < self.node_count
}

/// Get the output of all tasks.
pub fn get_results<T: Send + Sync + 'static>(&self) -> HashMap<NodeId, Option<Arc<T>>> {
self.execute_states
.iter()
.map(|(&id, state)| {
let output = match state.get_output() {
Some(content) => content.into_inner(),
None => None,
};
(id, output)
})
.collect()
}
pub fn get_outputs(&self) -> HashMap<NodeId, Output> {
self.execute_states
.iter()
.map(|(&id, state)| {
let t = state.get_full_output();
(id, t)
})
.collect()
}

    /// Before the dag starts executing, set the dag's global environment variable.
    ///
    /// The environment is wrapped in an `Arc`, replacing any previously set
    /// environment, so it can be shared across nodes during execution.
    pub fn set_env(&mut self, env: EnvVar) {
        self.env = Arc::new(env);
    }

///Remove duplicate elements
fn remove_duplicates<T>(vec: Vec<T>) -> Vec<T>
where
T: Eq + Hash + Clone,
{
let mut seen = HashSet::new();
vec.into_iter().filter(|x| seen.insert(x.clone())).collect()
}
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit 889b6de

Please sign in to comment.