Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: avoid cloning the input/outputs + remove anymap2 #44

Merged
merged 9 commits into from
Nov 27, 2023
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ members = ["derive", "."]
yaml-rust = { version = "0.4.5", optional = true }
bimap = "0.6.1"
clap = { version = "4.2.2", features = ["derive"] }
anymap2 = "0.13.0"
tokio = { version = "1.28", features = ["rt", "sync", "rt-multi-thread"] }
derive = { path = "derive", version = "0.3.0" }
thiserror = "1.0.50"
Expand Down
2 changes: 2 additions & 0 deletions benches/compute_dag_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ fn compute_dag(tasks: Vec<DefaultTask>) {
dag.set_env(env);

assert!(dag.start().unwrap());
// Get execution result.
let _res = dag.get_result::<usize>().unwrap();
}

fn compute_dag_bench(bencher: &mut Criterion) {
Expand Down
10 changes: 8 additions & 2 deletions examples/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ fn main() {
// Execute dag in order, the order should be dag1, dag2, dag3.
assert_eq!(engine.run_sequential(), vec![true, true, true]);
// Get the execution results of dag1 and dag2.
assert_eq!(engine.get_dag_result::<usize>("graph1").unwrap(), 100);
assert_eq!(engine.get_dag_result::<usize>("graph2").unwrap(), 1024);
assert_eq!(
engine.get_dag_result::<usize>("graph1").unwrap().as_ref(),
&100
);
assert_eq!(
engine.get_dag_result::<usize>("graph2").unwrap().as_ref(),
&1024
);
}
13 changes: 5 additions & 8 deletions src/engine/dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::{
utils::EnvVar,
Action, Parser,
};
use anymap2::any::CloneAnySendSync;
use log::{debug, error};
use std::{
collections::HashMap,
Expand Down Expand Up @@ -274,9 +273,7 @@ impl Dag {
return true;
}
if let Some(content) = wait_for.get_output() {
if !content.is_empty() {
inputs.push(content);
}
inputs.push(content);
}
}
debug!("Executing task [name: {}, id: {}]", task_name, task_id);
Expand Down Expand Up @@ -343,24 +340,24 @@ impl Dag {
}

/// Get the final execution result.
pub fn get_result<T: CloneAnySendSync + Send + Sync>(&self) -> Option<T> {
pub fn get_result<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
if self.exe_sequence.is_empty() {
None
} else {
let last_id = self.exe_sequence.last().unwrap();
match self.execute_states[last_id].get_output() {
Some(ref content) => content.clone().remove(),
Some(content) => content.into_inner(),
None => None,
}
}
}

/// Get the output of all tasks.
pub fn get_results<T: CloneAnySendSync + Send + Sync>(&self) -> HashMap<usize, Option<T>> {
pub fn get_results<T: Send + Sync + 'static>(&self) -> HashMap<usize, Option<Arc<T>>> {
let mut hm = HashMap::new();
for (id, state) in &self.execute_states {
let output = match state.get_output() {
Some(ref content) => content.clone().remove(),
Some(content) => content.into_inner(),
None => None,
};
hm.insert(*id, output);
Expand Down
5 changes: 2 additions & 3 deletions src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ mod dag;
mod graph;

use crate::ParseError;
use anymap2::any::CloneAnySendSync;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use tokio::runtime::Runtime;

/// The Engine. Manage multiple Dags.
Expand Down Expand Up @@ -92,7 +91,7 @@ impl Engine {
}

/// Given the name of the Dag, get the execution result of the specified Dag.
pub fn get_dag_result<T: CloneAnySendSync + Send + Sync>(&self, name: &str) -> Option<T> {
pub fn get_dag_result<T: Send + Sync + Clone + 'static>(&self, name: &str) -> Option<Arc<T>> {
if self.dags.contains_key(name) {
self.dags.get(name).unwrap().get_result()
} else {
Expand Down
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
extern crate anymap2;
extern crate bimap;
extern crate clap;
#[cfg(feature = "derive")]
Expand Down
2 changes: 1 addition & 1 deletion src/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use std::sync::atomic::AtomicUsize;
pub use self::action::{Action, Complex, Simple};
pub use self::cmd::CommandAction;
pub use self::default_task::DefaultTask;
pub(crate) use self::state::ExecState;
pub(crate) use self::state::{ExecState, Content};
pub use self::state::{Input, Output};

mod action;
Expand Down
41 changes: 33 additions & 8 deletions src/task/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,42 @@
//! to implement the logic of the program.
use std::{
any::Any,
slice::Iter,
sync::atomic::{AtomicBool, AtomicPtr, Ordering},
sync::{
atomic::{AtomicBool, AtomicPtr, Ordering},
Arc,
},
};

use anymap2::{any::CloneAnySendSync, Map};
use tokio::sync::Semaphore;

/// Container type to store task output.
type Content = Map<dyn CloneAnySendSync + Send + Sync>;
#[derive(Debug, Clone)]
pub struct Content {
content: Arc<dyn Any + Send + Sync>,
}

impl Content {
/// Construct a new [`Content`].
pub fn new<H: Send + Sync + 'static>(val: H) -> Self {
Self {
content: Arc::new(val),
}
}

pub fn from_arc<H: Send + Sync + 'static>(val: Arc<H>) -> Self {
Self { content: val }
}

pub fn get<H: 'static>(&self) -> Option<&H> {
self.content.downcast_ref::<H>()
}

pub fn into_inner<H: Send + Sync + 'static>(self) -> Option<Arc<H>> {
self.content.downcast::<H>().ok()
}
}

/// [`ExeState`] internally stores [`Output`], which represents whether the execution of
/// the task is successful, and its internal semaphore is used to synchronously obtain
Expand Down Expand Up @@ -149,11 +176,9 @@ impl Output {
/// Construct a new [`Output`].
///
/// Since the return value may be transferred between threads,
/// [`Send`], [`Sync`], [`CloneAnySendSync`] is needed.
pub fn new<H: Send + Sync + CloneAnySendSync>(val: H) -> Self {
let mut map = Content::new();
assert!(map.insert(val).is_none(), "[Error] map insert fails.");
Self::Out(Some(map))
/// [`Send`], [`Sync`] is needed.
pub fn new<H: Send + Sync + 'static>(val: H) -> Self {
Self::Out(Some(Content::new(val)))
}

/// Construct an empty [`Output`].
Expand Down
22 changes: 14 additions & 8 deletions src/utils/env.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use anymap2::{any::CloneAnySendSync, Map};
use crate::task::Content;

use std::collections::HashMap;

pub type Variable = Map<dyn CloneAnySendSync + Send + Sync>;
pub type Variable = Content;

/// # Environment variable.
///
Expand Down Expand Up @@ -31,17 +32,22 @@ impl EnvVar {
/// # let mut env = dagrs::EnvVar::new();
/// env.set("Hello", "World".to_string());
/// ```
pub fn set<H: Send + Sync + CloneAnySendSync>(&mut self, name: &str, var: H) {
let mut v = Variable::new();
v.insert(var);
pub fn set<H: Send + Sync + 'static>(&mut self, name: &str, var: H) {
let mut v = Variable::new(var);
self.variables.insert(name.to_owned(), v);
}

#[allow(unused)]
/// Get environment variables through keys of type &str.
pub fn get<H: Send + Sync + CloneAnySendSync>(&self, name: &str) -> Option<H> {
///
/// Note: This method will clone the value. To avoid cloning, use [`get_ref`].
pub fn get<H: Send + Sync + Clone + 'static>(&self, name: &str) -> Option<H> {
self.get_ref(name).cloned()
}

/// Get environment variables through keys of type &str.
pub fn get_ref<H: Send + Sync + 'static>(&self, name: &str) -> Option<&H> {
if let Some(content) = self.variables.get(name) {
content.clone().remove()
content.get()
} else {
None
}
Expand Down
2 changes: 1 addition & 1 deletion tests/env_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ fn env_set_get_test() {
let env = init_env();
assert_eq!(env.get::<usize>("test1"), Some(1usize));
assert_eq!(env.get::<usize>("test2"), None);
assert_eq!(env.get::<String>("test3"), Some("3".to_string()))
assert_eq!(env.get_ref::<String>("test3"), Some(&"3".to_string()))
}

#[test]
Expand Down