Skip to content

Commit

Permalink
Merge pull request #44 from aminya/clone
Browse files Browse the repository at this point in the history
fix!: avoid cloning the input/outputs + remove anymap2
  • Loading branch information
genedna authored Nov 27, 2023
2 parents a4b8255 + d971b03 commit 8bbb037
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 33 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ members = ["derive", "."]
yaml-rust = { version = "0.4.5", optional = true }
bimap = "0.6.1"
clap = { version = "4.2.2", features = ["derive"] }
anymap2 = "0.13.0"
tokio = { version = "1.28", features = ["rt", "sync", "rt-multi-thread"] }
derive = { path = "derive", version = "0.3.0" }
thiserror = "1.0.50"
Expand Down
2 changes: 2 additions & 0 deletions benches/compute_dag_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ fn compute_dag(tasks: Vec<DefaultTask>) {
dag.set_env(env);

assert!(dag.start().unwrap());
// Get execution result.
let _res = dag.get_result::<usize>().unwrap();
}

fn compute_dag_bench(bencher: &mut Criterion) {
Expand Down
10 changes: 8 additions & 2 deletions examples/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ fn main() {
// Execute dag in order, the order should be dag1, dag2, dag3.
assert_eq!(engine.run_sequential(), vec![true, true, true]);
// Get the execution results of dag1 and dag2.
assert_eq!(engine.get_dag_result::<usize>("graph1").unwrap(), 100);
assert_eq!(engine.get_dag_result::<usize>("graph2").unwrap(), 1024);
assert_eq!(
engine.get_dag_result::<usize>("graph1").unwrap().as_ref(),
&100
);
assert_eq!(
engine.get_dag_result::<usize>("graph2").unwrap().as_ref(),
&1024
);
}
13 changes: 5 additions & 8 deletions src/engine/dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::{
utils::EnvVar,
Action, Parser,
};
use anymap2::any::CloneAnySendSync;
use log::{debug, error};
use std::{
collections::HashMap,
Expand Down Expand Up @@ -274,9 +273,7 @@ impl Dag {
return true;
}
if let Some(content) = wait_for.get_output() {
if !content.is_empty() {
inputs.push(content);
}
inputs.push(content);
}
}
debug!("Executing task [name: {}, id: {}]", task_name, task_id);
Expand Down Expand Up @@ -343,24 +340,24 @@ impl Dag {
}

/// Get the final execution result.
pub fn get_result<T: CloneAnySendSync + Send + Sync>(&self) -> Option<T> {
pub fn get_result<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
if self.exe_sequence.is_empty() {
None
} else {
let last_id = self.exe_sequence.last().unwrap();
match self.execute_states[last_id].get_output() {
Some(ref content) => content.clone().remove(),
Some(content) => content.into_inner(),
None => None,
}
}
}

/// Get the output of all tasks.
pub fn get_results<T: CloneAnySendSync + Send + Sync>(&self) -> HashMap<usize, Option<T>> {
pub fn get_results<T: Send + Sync + 'static>(&self) -> HashMap<usize, Option<Arc<T>>> {
let mut hm = HashMap::new();
for (id, state) in &self.execute_states {
let output = match state.get_output() {
Some(ref content) => content.clone().remove(),
Some(content) => content.into_inner(),
None => None,
};
hm.insert(*id, output);
Expand Down
5 changes: 2 additions & 3 deletions src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ mod dag;
mod graph;

use crate::ParseError;
use anymap2::any::CloneAnySendSync;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use tokio::runtime::Runtime;

/// The Engine. Manage multiple Dags.
Expand Down Expand Up @@ -92,7 +91,7 @@ impl Engine {
}

/// Given the name of the Dag, get the execution result of the specified Dag.
pub fn get_dag_result<T: CloneAnySendSync + Send + Sync>(&self, name: &str) -> Option<T> {
pub fn get_dag_result<T: Send + Sync + Clone + 'static>(&self, name: &str) -> Option<Arc<T>> {
if self.dags.contains_key(name) {
self.dags.get(name).unwrap().get_result()
} else {
Expand Down
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
extern crate anymap2;
extern crate bimap;
extern crate clap;
#[cfg(feature = "derive")]
Expand Down
2 changes: 1 addition & 1 deletion src/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use std::sync::atomic::AtomicUsize;
pub use self::action::{Action, Complex, Simple};
pub use self::cmd::CommandAction;
pub use self::default_task::DefaultTask;
pub(crate) use self::state::ExecState;
pub(crate) use self::state::{ExecState, Content};
pub use self::state::{Input, Output};

mod action;
Expand Down
41 changes: 33 additions & 8 deletions src/task/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,42 @@
//! to implement the logic of the program.
use std::{
any::Any,
slice::Iter,
sync::atomic::{AtomicBool, AtomicPtr, Ordering},
sync::{
atomic::{AtomicBool, AtomicPtr, Ordering},
Arc,
},
};

use anymap2::{any::CloneAnySendSync, Map};
use tokio::sync::Semaphore;

/// Container type to store task output.
type Content = Map<dyn CloneAnySendSync + Send + Sync>;
#[derive(Debug, Clone)]
pub struct Content {
content: Arc<dyn Any + Send + Sync>,
}

impl Content {
/// Construct a new [`Content`].
pub fn new<H: Send + Sync + 'static>(val: H) -> Self {
Self {
content: Arc::new(val),
}
}

pub fn from_arc<H: Send + Sync + 'static>(val: Arc<H>) -> Self {
Self { content: val }
}

pub fn get<H: 'static>(&self) -> Option<&H> {
self.content.downcast_ref::<H>()
}

pub fn into_inner<H: Send + Sync + 'static>(self) -> Option<Arc<H>> {
self.content.downcast::<H>().ok()
}
}

/// [`ExeState`] internally stores [`Output`], which represents whether the execution of
/// the task is successful, and its internal semaphore is used to synchronously obtain
Expand Down Expand Up @@ -149,11 +176,9 @@ impl Output {
/// Construct a new [`Output`].
///
/// Since the return value may be transferred between threads,
/// [`Send`], [`Sync`], [`CloneAnySendSync`] is needed.
pub fn new<H: Send + Sync + CloneAnySendSync>(val: H) -> Self {
let mut map = Content::new();
assert!(map.insert(val).is_none(), "[Error] map insert fails.");
Self::Out(Some(map))
/// [`Send`], [`Sync`] is needed.
pub fn new<H: Send + Sync + 'static>(val: H) -> Self {
Self::Out(Some(Content::new(val)))
}

/// Construct an empty [`Output`].
Expand Down
22 changes: 14 additions & 8 deletions src/utils/env.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use anymap2::{any::CloneAnySendSync, Map};
use crate::task::Content;

use std::collections::HashMap;

pub type Variable = Map<dyn CloneAnySendSync + Send + Sync>;
pub type Variable = Content;

/// # Environment variable.
///
Expand Down Expand Up @@ -31,17 +32,22 @@ impl EnvVar {
/// # let mut env = dagrs::EnvVar::new();
/// env.set("Hello", "World".to_string());
/// ```
pub fn set<H: Send + Sync + CloneAnySendSync>(&mut self, name: &str, var: H) {
let mut v = Variable::new();
v.insert(var);
pub fn set<H: Send + Sync + 'static>(&mut self, name: &str, var: H) {
let mut v = Variable::new(var);
self.variables.insert(name.to_owned(), v);
}

#[allow(unused)]
/// Get environment variables through keys of type &str.
pub fn get<H: Send + Sync + CloneAnySendSync>(&self, name: &str) -> Option<H> {
///
/// Note: This method will clone the value. To avoid cloning, use [`get_ref`].
pub fn get<H: Send + Sync + Clone + 'static>(&self, name: &str) -> Option<H> {
self.get_ref(name).cloned()
}

/// Get environment variables through keys of type &str.
pub fn get_ref<H: Send + Sync + 'static>(&self, name: &str) -> Option<&H> {
if let Some(content) = self.variables.get(name) {
content.clone().remove()
content.get()
} else {
None
}
Expand Down
2 changes: 1 addition & 1 deletion tests/env_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ fn env_set_get_test() {
let env = init_env();
assert_eq!(env.get::<usize>("test1"), Some(1usize));
assert_eq!(env.get::<usize>("test2"), None);
assert_eq!(env.get::<String>("test3"), Some("3".to_string()))
assert_eq!(env.get_ref::<String>("test3"), Some(&"3".to_string()))
}

#[test]
Expand Down

0 comments on commit 8bbb037

Please sign in to comment.