From d10eac64931dac87d2b6455ff5abef2142ccdac5 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 1 Mar 2025 14:51:38 +0100 Subject: [PATCH 1/5] Adjust safety argument of `par_map` --- src/par_map.rs | 44 +++++++++++++++++--------------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/src/par_map.rs b/src/par_map.rs index 1f93a1c5f..111b74c51 100644 --- a/src/par_map.rs +++ b/src/par_map.rs @@ -1,26 +1,24 @@ -use std::ops::Deref; - use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; use crate::Database; -pub fn par_map( - db: &Db, - inputs: impl IntoParallelIterator, - op: fn(&Db, D) -> E, -) -> C +pub fn par_map(db: &Db, inputs: impl IntoParallelIterator, op: F) -> C where Db: Database + ?Sized, - D: Send, - E: Send + Sync, - C: FromParallelIterator, + F: Fn(&Db, T) -> R + Sync + Send, + T: Send, + R: Send + Sync, + C: FromParallelIterator, { let parallel_db = ParallelDb::Ref(db.as_dyn_database()); inputs .into_par_iter() .map_with(parallel_db, |parallel_db, element| { - let db = parallel_db.as_view::(); + let db = match parallel_db { + ParallelDb::Ref(db) => db.as_view::(), + ParallelDb::Fork(db) => db.as_view::(), + }; op(db, element) }) .collect() @@ -29,26 +27,18 @@ where /// This enum _must not_ be public or used outside of `par_map`. enum ParallelDb<'db> { Ref(&'db dyn Database), - Fork(Box), + Fork(Box), } -/// SAFETY: the contents of the database are never accessed on the thread -/// where this wrapper type is created. -unsafe impl Send for ParallelDb<'_> {} - -impl Deref for ParallelDb<'_> { - type Target = dyn Database; - - fn deref(&self) -> &Self::Target { - match self { - ParallelDb::Ref(db) => *db, - ParallelDb::Fork(db) => db.as_dyn_database(), - } - } -} +/// SAFETY: We guarantee that the `&'db dyn Database` reference is not copied and as such it is +/// never referenced on multiple threads at once. +unsafe impl Send for ParallelDb<'_> where dyn Database: Send {} impl Clone for ParallelDb<'_> { fn clone(&self) -> Self { - ParallelDb::Fork(self.fork_db()) + ParallelDb::Fork(match self { + ParallelDb::Ref(db) => db.fork_db(), + ParallelDb::Fork(db) => db.fork_db(), + }) } } From 920bace98127ac3e914ecade5e6c7cc2d4ee8598 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Thu, 20 Mar 2025 14:08:33 +0100 Subject: [PATCH 2/5] More parallel APIs --- src/lib.rs | 2 +- src/par_map.rs | 77 ++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 88d6fbd78..90293db03 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,7 +51,7 @@ pub use self::update::Update; pub use self::zalsa::IngredientIndex; pub use crate::attach::with_attached_database; #[cfg(feature = "rayon")] -pub use par_map::par_map; +pub use par_map::{join, par_map, scope, Scope}; #[cfg(feature = "macros")] pub use salsa_macros::{accumulator, db, input, interned, tracked, Supertype, Update}; diff --git a/src/par_map.rs b/src/par_map.rs index 111b74c51..f2b21592e 100644 --- a/src/par_map.rs +++ b/src/par_map.rs @@ -15,11 +15,7 @@ where inputs .into_par_iter() .map_with(parallel_db, |parallel_db, element| { - let db = match parallel_db { - ParallelDb::Ref(db) => db.as_view::(), - ParallelDb::Fork(db) => db.as_view::(), - }; - op(db, element) + op(parallel_db.as_view(), element) }) .collect() } @@ -34,6 +30,22 @@ enum ParallelDb<'db> { /// never referenced on multiple threads at once. unsafe impl Send for ParallelDb<'_> where dyn Database: Send {} +impl ParallelDb<'_> { + fn fork(&self) -> ParallelDb<'static> { + ParallelDb::Fork(match self { + ParallelDb::Ref(db) => db.fork_db(), + ParallelDb::Fork(db) => db.fork_db(), + }) + } + + fn as_view(&self) -> &Db { + match self { + ParallelDb::Ref(db) => db.as_view::(), + ParallelDb::Fork(db) => db.as_view::(), + } + } +} + impl Clone for ParallelDb<'_> { fn clone(&self) -> Self { ParallelDb::Fork(match self { @@ -42,3 +54,58 @@ impl Clone for ParallelDb<'_> { }) } } + +pub struct Scope<'scope, 'local, Db: Database + ?Sized> { + db: ParallelDb<'local>, + base: &'local rayon::Scope<'scope>, + phantom: std::marker::PhantomData Db>, +} + +impl<'scope, Db: Database + ?Sized> Scope<'scope, '_, Db> { + pub fn spawn(&self, body: BODY) + where + BODY: for<'l> FnOnce(&'l Scope<'scope, 'l, Db>, &Db) + Send + 'scope, + { + let db = self.db.fork(); + self.base.spawn(move |scope| { + let scope = Scope { + db, + base: scope, + phantom: std::marker::PhantomData, + }; + body(&scope, scope.db.as_view::()) + }) + } +} + +pub fn scope<'scope, Db: Database + ?Sized, OP, R>(db: &Db, op: OP) -> R +where + OP: FnOnce(&Scope<'scope, '_, Db>, &Db) -> R + Send, + R: Send, +{ + rayon::in_place_scope(move |s| { + let scope = Scope { + db: ParallelDb::Ref(db.as_dyn_database()), + base: s, + phantom: std::marker::PhantomData, + }; + op(&scope, db) + }) +} + +pub fn join(db: &Db, a: A, b: B) -> (RA, RB) +where + A: FnOnce(&Db) -> RA + Send, + B: FnOnce(&Db) -> RB + Send, + RA: Send, + RB: Send, +{ + // we need to fork eagerly, as `rayon::join_context` gives us no option to tell whether we get + // moved to another thread before the closure is executed + let db_a = db.fork_db(); + let db_b = db.fork_db(); + rayon::join( + move || a(db_a.as_view::()), + move || b(db_b.as_view::()), + ) +} From da68e9e436a6429578588929c94ecb16f7dfdfd9 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 22 Mar 2025 12:56:08 +0100 Subject: [PATCH 3/5] assert `ParallelDb` `Send` promise --- src/par_map.rs | 49 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/src/par_map.rs b/src/par_map.rs index f2b21592e..c4d368be9 100644 --- a/src/par_map.rs +++ b/src/par_map.rs @@ -1,3 +1,5 @@ +use std::thread::{self, ThreadId}; + use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; use crate::Database; @@ -10,7 +12,11 @@ where R: Send + Sync, C: FromParallelIterator, { - let parallel_db = ParallelDb::Ref(db.as_dyn_database()); + let parallel_db = ParallelDb::Ref( + db.as_dyn_database(), + #[cfg(debug_assertions)] + thread::current().id(), + ); inputs .into_par_iter() @@ -22,7 +28,11 @@ where /// This enum _must not_ be public or used outside of `par_map`. enum ParallelDb<'db> { - Ref(&'db dyn Database), + Ref( + &'db dyn Database, + #[cfg(debug_assertions)] ThreadId, + #[cfg(not(debug_assertions))] (), + ), Fork(Box), } @@ -31,16 +41,36 @@ enum ParallelDb<'db> { unsafe impl Send for ParallelDb<'_> where dyn Database: Send {} impl ParallelDb<'_> { + #[inline] + #[track_caller] + fn thread_id_assert(&self) { + #[cfg(debug_assertions)] + { + match self { + ParallelDb::Ref(_, thread_id) => { + assert_eq!( + thread_id, + &thread::current().id(), + "reference was smuggled to another thread!" + ); + } + ParallelDb::Fork(_) => {} + } + } + } + fn fork(&self) -> ParallelDb<'static> { + self.thread_id_assert(); ParallelDb::Fork(match self { - ParallelDb::Ref(db) => db.fork_db(), + ParallelDb::Ref(db, _) => db.fork_db(), ParallelDb::Fork(db) => db.fork_db(), }) } fn as_view(&self) -> &Db { + self.thread_id_assert(); match self { - ParallelDb::Ref(db) => db.as_view::(), + ParallelDb::Ref(db, _) => db.as_view::(), ParallelDb::Fork(db) => db.as_view::(), } } @@ -48,8 +78,9 @@ impl ParallelDb<'_> { impl Clone for ParallelDb<'_> { fn clone(&self) -> Self { + self.thread_id_assert(); ParallelDb::Fork(match self { - ParallelDb::Ref(db) => db.fork_db(), + ParallelDb::Ref(db, _) => db.fork_db(), ParallelDb::Fork(db) => db.fork_db(), }) } @@ -85,7 +116,13 @@ where { rayon::in_place_scope(move |s| { let scope = Scope { - db: ParallelDb::Ref(db.as_dyn_database()), + db: ParallelDb::Ref( + db.as_dyn_database(), + #[cfg(debug_assertions)] + thread::current().id(), + #[cfg(not(debug_assertions))] + (), + ), base: s, phantom: std::marker::PhantomData, }; From bc8e40679dd052f178e4bcab4b3726b895fb543e Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sun, 23 Mar 2025 12:23:09 +0100 Subject: [PATCH 4/5] fix: Fix `par_map` unsoundness --- src/par_map.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/par_map.rs b/src/par_map.rs index c4d368be9..d47b725d0 100644 --- a/src/par_map.rs +++ b/src/par_map.rs @@ -12,20 +12,22 @@ where R: Send + Sync, C: FromParallelIterator, { - let parallel_db = ParallelDb::Ref( - db.as_dyn_database(), - #[cfg(debug_assertions)] - thread::current().id(), - ); - inputs .into_par_iter() - .map_with(parallel_db, |parallel_db, element| { - op(parallel_db.as_view(), element) + .map_with(DbForkOnClone(db.fork_db()), |db, element| { + op(db.0.as_view(), element) }) .collect() } +struct DbForkOnClone(Box); + +impl Clone for DbForkOnClone { + fn clone(&self) -> Self { + DbForkOnClone(self.0.fork_db()) + } +} + /// This enum _must not_ be public or used outside of `par_map`. enum ParallelDb<'db> { Ref( From b9379b0e5df364ced757e4cbd37ebdd87d9d27e0 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Mon, 24 Mar 2025 07:38:05 +0100 Subject: [PATCH 5/5] Add more parallel API tests --- src/lib.rs | 4 +- src/par_map.rs | 150 ------------------------------- src/parallel.rs | 77 ++++++++++++++++ tests/parallel/main.rs | 2 + tests/parallel/parallel_join.rs | 105 ++++++++++++++++++++++ tests/parallel/parallel_map.rs | 2 +- tests/parallel/parallel_scope.rs | 116 ++++++++++++++++++++++++ tests/parallel/setup.rs | 2 +- 8 files changed, 304 insertions(+), 154 deletions(-) delete mode 100644 src/par_map.rs create mode 100644 src/parallel.rs create mode 100644 tests/parallel/parallel_join.rs create mode 100644 tests/parallel/parallel_scope.rs diff --git a/src/lib.rs b/src/lib.rs index 90293db03..eb29ac954 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,7 +20,7 @@ mod key; mod memo_ingredient_indices; mod nonce; #[cfg(feature = "rayon")] -mod par_map; +mod parallel; mod revision; mod runtime; mod salsa_struct; @@ -51,7 +51,7 @@ pub use self::update::Update; pub use self::zalsa::IngredientIndex; pub use crate::attach::with_attached_database; #[cfg(feature = "rayon")] -pub use par_map::{join, par_map, scope, Scope}; +pub use parallel::{join, par_map, scope, Scope}; #[cfg(feature = "macros")] pub use salsa_macros::{accumulator, db, input, interned, tracked, Supertype, Update}; diff --git a/src/par_map.rs b/src/par_map.rs deleted file mode 100644 index d47b725d0..000000000 --- a/src/par_map.rs +++ /dev/null @@ -1,150 +0,0 @@ -use std::thread::{self, ThreadId}; - -use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; - -use crate::Database; - -pub fn par_map(db: &Db, inputs: impl IntoParallelIterator, op: F) -> C -where - Db: Database + ?Sized, - F: Fn(&Db, T) -> R + Sync + Send, - T: Send, - R: Send + Sync, - C: FromParallelIterator, -{ - inputs - .into_par_iter() - .map_with(DbForkOnClone(db.fork_db()), |db, element| { - op(db.0.as_view(), element) - }) - .collect() -} - -struct DbForkOnClone(Box); - -impl Clone for DbForkOnClone { - fn clone(&self) -> Self { - DbForkOnClone(self.0.fork_db()) - } -} - -/// This enum _must not_ be public or used outside of `par_map`. -enum ParallelDb<'db> { - Ref( - &'db dyn Database, - #[cfg(debug_assertions)] ThreadId, - #[cfg(not(debug_assertions))] (), - ), - Fork(Box), -} - -/// SAFETY: We guarantee that the `&'db dyn Database` reference is not copied and as such it is -/// never referenced on multiple threads at once. -unsafe impl Send for ParallelDb<'_> where dyn Database: Send {} - -impl ParallelDb<'_> { - #[inline] - #[track_caller] - fn thread_id_assert(&self) { - #[cfg(debug_assertions)] - { - match self { - ParallelDb::Ref(_, thread_id) => { - assert_eq!( - thread_id, - &thread::current().id(), - "reference was smuggled to another thread!" - ); - } - ParallelDb::Fork(_) => {} - } - } - } - - fn fork(&self) -> ParallelDb<'static> { - self.thread_id_assert(); - ParallelDb::Fork(match self { - ParallelDb::Ref(db, _) => db.fork_db(), - ParallelDb::Fork(db) => db.fork_db(), - }) - } - - fn as_view(&self) -> &Db { - self.thread_id_assert(); - match self { - ParallelDb::Ref(db, _) => db.as_view::(), - ParallelDb::Fork(db) => db.as_view::(), - } - } -} - -impl Clone for ParallelDb<'_> { - fn clone(&self) -> Self { - self.thread_id_assert(); - ParallelDb::Fork(match self { - ParallelDb::Ref(db, _) => db.fork_db(), - ParallelDb::Fork(db) => db.fork_db(), - }) - } -} - -pub struct Scope<'scope, 'local, Db: Database + ?Sized> { - db: ParallelDb<'local>, - base: &'local rayon::Scope<'scope>, - phantom: std::marker::PhantomData Db>, -} - -impl<'scope, Db: Database + ?Sized> Scope<'scope, '_, Db> { - pub fn spawn(&self, body: BODY) - where - BODY: for<'l> FnOnce(&'l Scope<'scope, 'l, Db>, &Db) + Send + 'scope, - { - let db = self.db.fork(); - self.base.spawn(move |scope| { - let scope = Scope { - db, - base: scope, - phantom: std::marker::PhantomData, - }; - body(&scope, scope.db.as_view::()) - }) - } -} - -pub fn scope<'scope, Db: Database + ?Sized, OP, R>(db: &Db, op: OP) -> R -where - OP: FnOnce(&Scope<'scope, '_, Db>, &Db) -> R + Send, - R: Send, -{ - rayon::in_place_scope(move |s| { - let scope = Scope { - db: ParallelDb::Ref( - db.as_dyn_database(), - #[cfg(debug_assertions)] - thread::current().id(), - #[cfg(not(debug_assertions))] - (), - ), - base: s, - phantom: std::marker::PhantomData, - }; - op(&scope, db) - }) -} - -pub fn join(db: &Db, a: A, b: B) -> (RA, RB) -where - A: FnOnce(&Db) -> RA + Send, - B: FnOnce(&Db) -> RB + Send, - RA: Send, - RB: Send, -{ - // we need to fork eagerly, as `rayon::join_context` gives us no option to tell whether we get - // moved to another thread before the closure is executed - let db_a = db.fork_db(); - let db_b = db.fork_db(); - rayon::join( - move || a(db_a.as_view::()), - move || b(db_b.as_view::()), - ) -} diff --git a/src/parallel.rs b/src/parallel.rs new file mode 100644 index 000000000..d3033af5b --- /dev/null +++ b/src/parallel.rs @@ -0,0 +1,77 @@ +use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; + +use crate::Database; + +pub fn par_map(db: &Db, inputs: impl IntoParallelIterator, op: F) -> C +where + Db: Database + ?Sized, + F: Fn(&Db, T) -> R + Sync + Send, + T: Send, + R: Send + Sync, + C: FromParallelIterator, +{ + inputs + .into_par_iter() + .map_with(DbForkOnClone(db.fork_db()), |db, element| { + op(db.0.as_view(), element) + }) + .collect() +} + +struct DbForkOnClone(Box); + +impl Clone for DbForkOnClone { + fn clone(&self) -> Self { + DbForkOnClone(self.0.fork_db()) + } +} + +pub struct Scope<'scope, 'local, Db: Database + ?Sized> { + db: &'local Db, + base: &'local rayon::Scope<'scope>, +} + +impl<'scope, 'local, Db: Database + ?Sized> Scope<'scope, 'local, Db> { + pub fn spawn(&self, body: BODY) + where + BODY: for<'l> FnOnce(&'l Scope<'scope, 'l, Db>) + Send + 'scope, + { + let db = self.db.fork_db(); + self.base.spawn(move |scope| { + let scope = Scope { + db: db.as_view::(), + base: scope, + }; + body(&scope) + }) + } + + pub fn db(&self) -> &'local Db { + self.db + } +} + +pub fn scope<'scope, Db: Database + ?Sized, OP, R>(db: &Db, op: OP) -> R +where + OP: FnOnce(&Scope<'scope, '_, Db>) -> R + Send, + R: Send, +{ + rayon::in_place_scope(move |s| op(&Scope { db, base: s })) +} + +pub fn join(db: &Db, a: A, b: B) -> (RA, RB) +where + A: FnOnce(&Db) -> RA + Send, + B: FnOnce(&Db) -> RB + Send, + RA: Send, + RB: Send, +{ + // we need to fork eagerly, as `rayon::join_context` gives us no option to tell whether we get + // moved to another thread before the closure is executed + let db_a = db.fork_db(); + let db_b = db.fork_db(); + rayon::join( + move || a(db_a.as_view::()), + move || b(db_b.as_view::()), + ) +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index cf02f64ac..4f4295463 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -5,5 +5,7 @@ mod cycle_ab_peeping_c; mod cycle_nested_three_threads; mod cycle_panic; mod parallel_cancellation; +mod parallel_join; mod parallel_map; +mod parallel_scope; mod signal; diff --git a/tests/parallel/parallel_join.rs b/tests/parallel/parallel_join.rs new file mode 100644 index 000000000..406d4f0bd --- /dev/null +++ b/tests/parallel/parallel_join.rs @@ -0,0 +1,105 @@ +#![cfg(feature = "rayon")] +// test for rayon-like join interactions. + +use salsa::Cancelled; +use salsa::Setter; + +use crate::setup::Knobs; +use crate::setup::KnobsDatabase; + +#[salsa::input] +struct ParallelInput { + a: u32, + b: u32, +} + +#[salsa::tracked] +fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> (u32, u32) { + salsa::join(db, |db| input.a(db) + 1, |db| input.b(db) - 1) +} + +#[salsa::tracked] +fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> (u32, u32) { + db.signal(1); + salsa::join( + db, + |db| { + db.wait_for(2); + input.a(db) + dummy(db) + }, + |db| { + db.wait_for(2); + input.b(db) + dummy(db) + }, + ) +} + +#[salsa::tracked] +fn dummy(_db: &dyn KnobsDatabase) -> u32 { + panic!("should never get here!") +} + +#[test] +#[cfg_attr(miri, ignore)] +fn execute() { + let db = salsa::DatabaseImpl::new(); + + let input = ParallelInput::new(&db, 10, 20); + + tracked_fn(&db, input); +} + +// we expect this to panic, as `salsa::par_map` needs to be called from a query. +#[test] +#[cfg_attr(miri, ignore)] +#[should_panic] +fn direct_calls_panic() { + let db = salsa::DatabaseImpl::new(); + + let input = ParallelInput::new(&db, 10, 20); + let (_, _) = salsa::join(&db, |db| input.a(db) + 1, |db| input.b(db) - 1); +} + +// Cancellation signalling test +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 +// | wait for stage 1 +// signal stage 1 set input, triggers cancellation +// wait for stage 2 (blocks) triggering cancellation sends stage 2 +// | +// (unblocked) +// dummy +// panics + +#[test] +#[cfg_attr(miri, ignore)] +fn execute_cancellation() { + let mut db = Knobs::default(); + + let input = ParallelInput::new(&db, 10, 20); + + let thread_a = std::thread::spawn({ + let db = db.clone(); + move || a1(&db, input) + }); + + db.signal_on_did_cancel(2); + input.set_a(&mut db).to(30); + + // Assert thread A was cancelled + let cancelled = thread_a + .join() + .unwrap_err() + .downcast::() + .unwrap(); + + // and inspect the output + expect_test::expect![[r#" + PendingWrite + "#]] + .assert_debug_eq(&cancelled); +} diff --git a/tests/parallel/parallel_map.rs b/tests/parallel/parallel_map.rs index 5d5bdbb18..2aa7def07 100644 --- a/tests/parallel/parallel_map.rs +++ b/tests/parallel/parallel_map.rs @@ -1,5 +1,5 @@ #![cfg(feature = "rayon")] -// test for rayon interactions. +// test for rayon-like parallel map interactions. use salsa::Cancelled; use salsa::Setter; diff --git a/tests/parallel/parallel_scope.rs b/tests/parallel/parallel_scope.rs new file mode 100644 index 000000000..930f7c569 --- /dev/null +++ b/tests/parallel/parallel_scope.rs @@ -0,0 +1,116 @@ +#![cfg(feature = "rayon")] +// test for rayon-like scope interactions. + +use salsa::Cancelled; +use salsa::Setter; + +use crate::setup::Knobs; +use crate::setup::KnobsDatabase; + +#[salsa::input] +struct ParallelInput { + a: u32, + b: u32, +} + +#[salsa::tracked] +fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> (u32, u32) { + let mut a = None; + let mut b = None; + salsa::scope(db, |scope| { + scope.spawn(|scope| a = Some(input.a(scope.db()) + 1)); + scope.spawn(|scope| b = Some(input.b(scope.db()) + 1)); + }); + (a.unwrap(), b.unwrap()) +} + +#[salsa::tracked] +fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> (u32, u32) { + db.signal(1); + let mut a = None; + let mut b = None; + salsa::scope(db, |scope| { + scope.spawn(|scope| { + scope.db().wait_for(2); + a = Some(input.a(scope.db()) + 1) + }); + scope.spawn(|scope| { + scope.db().wait_for(2); + b = Some(input.b(scope.db()) + 1) + }); + }); + (a.unwrap(), b.unwrap()) +} + +#[salsa::tracked] +fn dummy(_db: &dyn KnobsDatabase) -> u32 { + panic!("should never get here!") +} + +#[test] +#[cfg_attr(miri, ignore)] +fn execute() { + let db = salsa::DatabaseImpl::new(); + + let input = ParallelInput::new(&db, 10, 20); + + tracked_fn(&db, input); +} + +// we expect this to panic, as `salsa::par_map` needs to be called from a query. +#[test] +#[cfg_attr(miri, ignore)] +#[should_panic] +fn direct_calls_panic() { + let db = salsa::DatabaseImpl::new(); + + let input = ParallelInput::new(&db, 10, 20); + salsa::scope(&db, |scope| { + scope.spawn(|scope| _ = input.a(scope.db()) + 1); + scope.spawn(|scope| _ = input.b(scope.db()) + 1); + }); +} + +// Cancellation signalling test +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 +// | wait for stage 1 +// signal stage 1 set input, triggers cancellation +// wait for stage 2 (blocks) triggering cancellation sends stage 2 +// | +// (unblocked) +// dummy +// panics + +#[test] +#[cfg_attr(miri, ignore)] +fn execute_cancellation() { + let mut db = Knobs::default(); + + let input = ParallelInput::new(&db, 10, 20); + + let thread_a = std::thread::spawn({ + let db = db.clone(); + move || a1(&db, input) + }); + + db.signal_on_did_cancel(2); + input.set_a(&mut db).to(30); + + // Assert thread A was cancelled + let cancelled = thread_a + .join() + .unwrap_err() + .downcast::() + .unwrap(); + + // and inspect the output + expect_test::expect![[r#" + PendingWrite + "#]] + .assert_debug_eq(&cancelled); +} diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 52c0ce227..70b1a25a8 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -11,7 +11,7 @@ use crate::signal::Signal; /// a certain behavior. #[salsa::db] pub(crate) trait KnobsDatabase: Database { - /// Signal that we are entering stage 1. + /// Signal that we are entering stage `stage`. fn signal(&self, stage: usize); /// Wait until we reach stage `stage` (no-op if we have already reached that stage).