Skip to content

Commit

Permalink
fix!: return Arc instead of own in get_result(s)
Browse files Browse the repository at this point in the history
BREAKING this returns an Arc<T> instead of T in the get_result methods

Signed-off-by: Amin Yahyaabadi <[email protected]>
  • Loading branch information
aminya committed Nov 27, 2023
1 parent 3cf6d2d commit 2923117
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 37 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ fn main() {
// Start executing this dag
assert!(dag.start().unwrap());
// Get execution result.
let res = dag.into_result::<usize>().unwrap();
let res = dag.get_result::<usize>().unwrap();
println!("The result is {}.", res);
}
```
Expand Down
2 changes: 1 addition & 1 deletion benches/compute_dag_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ fn compute_dag(tasks: Vec<DefaultTask>) {

assert!(dag.start().unwrap());
// Get execution result.
let _res = dag.into_result::<usize>().unwrap();
let _res = dag.get_result::<usize>().unwrap();
}

fn compute_dag_bench(bencher: &mut Criterion) {
Expand Down
2 changes: 1 addition & 1 deletion examples/compute_dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,6 @@ fn main() {
// Start executing this dag
assert!(dag.start().unwrap());
// Get execution result.
let res = dag.into_result::<usize>().unwrap();
let res = dag.get_result::<usize>().unwrap();
println!("The result is {}.", res);
}
10 changes: 8 additions & 2 deletions examples/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ fn main() {
// Execute dag in order, the order should be dag1, dag2, dag3.
assert_eq!(engine.run_sequential(), vec![true, true, true]);
// Get the execution results of dag1 and dag2.
assert_eq!(engine.get_dag_result::<usize>("graph1").unwrap(), 100);
assert_eq!(engine.get_dag_result::<usize>("graph2").unwrap(), 1024);
assert_eq!(
engine.get_dag_result::<usize>("graph1").unwrap().as_ref(),
&100
);
assert_eq!(
engine.get_dag_result::<usize>("graph2").unwrap().as_ref(),
&1024
);
}
32 changes: 2 additions & 30 deletions src/engine/dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ impl Dag {
}

/// Get the final execution result.
pub fn into_result<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
pub fn get_result<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
if self.exe_sequence.is_empty() {
None
} else {
Expand All @@ -352,15 +352,8 @@ impl Dag {
}
}

/// Get the final execution result.
///
/// Note: This method might clone the value if there are references to the value. To avoid cloning, use [`Dag::into_result`].
pub fn get_result<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
self.into_result::<T>().map(unwrap_or_clone)
}

/// Get the output of all tasks.
pub fn into_results<T: Send + Sync + 'static>(&self) -> HashMap<usize, Option<Arc<T>>> {
pub fn get_results<T: Send + Sync + 'static>(&self) -> HashMap<usize, Option<Arc<T>>> {
let mut hm = HashMap::new();
for (id, state) in &self.execute_states {
let output = match state.get_output() {
Expand All @@ -372,29 +365,8 @@ impl Dag {
hm
}

/// Get the output of all tasks.
///
/// Note: This method might clone the value if there are references to the value. To avoid cloning, use [`Dag::into_results`].
pub fn get_results<T: Send + Sync + Clone + 'static>(&self) -> HashMap<usize, Option<T>> {
let mut hm = HashMap::new();
for (id, state) in &self.execute_states {
let output = match state.get_output() {
Some(content) => content.into_inner().map(unwrap_or_clone),
None => None,
};
hm.insert(*id, output);
}
hm
}

/// Before the dag starts executing, set the dag's global environment variable.
pub fn set_env(&mut self, env: EnvVar) {
self.env = Arc::new(env);
}
}

/// Unwrap an Arc<T> to T if possible, otherwise clone the value.
// This can be removed when Arc::unwrap_or_clone is stabilized.
fn unwrap_or_clone<T: Send + Sync + Clone + 'static>(arc: Arc<T>) -> T {
Arc::try_unwrap(arc).unwrap_or_else(|arc| (*arc).clone())
}
4 changes: 2 additions & 2 deletions src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mod dag;
mod graph;

use crate::ParseError;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use tokio::runtime::Runtime;

/// The Engine. Manage multiple Dags.
Expand Down Expand Up @@ -91,7 +91,7 @@ impl Engine {
}

/// Given the name of the Dag, get the execution result of the specified Dag.
pub fn get_dag_result<T: Send + Sync + Clone + 'static>(&self, name: &str) -> Option<T> {
pub fn get_dag_result<T: Send + Sync + Clone + 'static>(&self, name: &str) -> Option<Arc<T>> {
if self.dags.contains_key(name) {
self.dags.get(name).unwrap().get_result()
} else {
Expand Down

0 comments on commit 2923117

Please sign in to comment.