Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod key;
mod memo_ingredient_indices;
mod nonce;
#[cfg(feature = "rayon")]
mod par_map;
mod parallel;
mod revision;
mod runtime;
mod salsa_struct;
Expand Down Expand Up @@ -51,7 +51,7 @@ pub use self::update::Update;
pub use self::zalsa::IngredientIndex;
pub use crate::attach::with_attached_database;
#[cfg(feature = "rayon")]
pub use par_map::par_map;
pub use parallel::{join, par_map, scope, Scope};
#[cfg(feature = "macros")]
pub use salsa_macros::{accumulator, db, input, interned, tracked, Supertype, Update};

Expand Down
54 changes: 0 additions & 54 deletions src/par_map.rs

This file was deleted.

77 changes: 77 additions & 0 deletions src/parallel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};

use crate::Database;

pub fn par_map<Db, F, T, R, C>(db: &Db, inputs: impl IntoParallelIterator<Item = T>, op: F) -> C
where
Db: Database + ?Sized,
F: Fn(&Db, T) -> R + Sync + Send,
T: Send,
R: Send + Sync,
C: FromParallelIterator<R>,
{
inputs
.into_par_iter()
.map_with(DbForkOnClone(db.fork_db()), |db, element| {
op(db.0.as_view(), element)
})
.collect()
}

struct DbForkOnClone(Box<dyn Database>);

impl Clone for DbForkOnClone {
fn clone(&self) -> Self {
DbForkOnClone(self.0.fork_db())
}
}

pub struct Scope<'scope, 'local, Db: Database + ?Sized> {
db: &'local Db,
base: &'local rayon::Scope<'scope>,
}

impl<'scope, 'local, Db: Database + ?Sized> Scope<'scope, 'local, Db> {
pub fn spawn<BODY>(&self, body: BODY)
where
BODY: for<'l> FnOnce(&'l Scope<'scope, 'l, Db>) + Send + 'scope,
{
let db = self.db.fork_db();
self.base.spawn(move |scope| {
let scope = Scope {
db: db.as_view::<Db>(),
base: scope,
};
body(&scope)
})
}

pub fn db(&self) -> &'local Db {
self.db
}
}

pub fn scope<'scope, Db: Database + ?Sized, OP, R>(db: &Db, op: OP) -> R
where
OP: FnOnce(&Scope<'scope, '_, Db>) -> R + Send,
R: Send,
{
rayon::in_place_scope(move |s| op(&Scope { db, base: s }))
}

pub fn join<A, B, RA, RB, Db: Database + ?Sized>(db: &Db, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&Db) -> RA + Send,
B: FnOnce(&Db) -> RB + Send,
RA: Send,
RB: Send,
{
// we need to fork eagerly, as `rayon::join_context` gives us no option to tell whether we get
// moved to another thread before the closure is executed
let db_a = db.fork_db();
let db_b = db.fork_db();
rayon::join(
move || a(db_a.as_view::<Db>()),
move || b(db_b.as_view::<Db>()),
)
}
2 changes: 2 additions & 0 deletions tests/parallel/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@ mod cycle_ab_peeping_c;
mod cycle_nested_three_threads;
mod cycle_panic;
mod parallel_cancellation;
mod parallel_join;
mod parallel_map;
mod parallel_scope;
mod signal;
105 changes: 105 additions & 0 deletions tests/parallel/parallel_join.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#![cfg(feature = "rayon")]
// test for rayon-like join interactions.

use salsa::Cancelled;
use salsa::Setter;

use crate::setup::Knobs;
use crate::setup::KnobsDatabase;

#[salsa::input]
struct ParallelInput {
a: u32,
b: u32,
}

#[salsa::tracked]
fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> (u32, u32) {
salsa::join(db, |db| input.a(db) + 1, |db| input.b(db) - 1)
}

#[salsa::tracked]
fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> (u32, u32) {
db.signal(1);
salsa::join(
db,
|db| {
db.wait_for(2);
input.a(db) + dummy(db)
},
|db| {
db.wait_for(2);
input.b(db) + dummy(db)
},
)
}

#[salsa::tracked]
fn dummy(_db: &dyn KnobsDatabase) -> u32 {
panic!("should never get here!")
}

#[test]
#[cfg_attr(miri, ignore)]
fn execute() {
let db = salsa::DatabaseImpl::new();

let input = ParallelInput::new(&db, 10, 20);

tracked_fn(&db, input);
}

// we expect this to panic, as `salsa::par_map` needs to be called from a query.
#[test]
#[cfg_attr(miri, ignore)]
#[should_panic]
fn direct_calls_panic() {
let db = salsa::DatabaseImpl::new();

let input = ParallelInput::new(&db, 10, 20);
let (_, _) = salsa::join(&db, |db| input.a(db) + 1, |db| input.b(db) - 1);
}

// Cancellation signalling test
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1
// | wait for stage 1
// signal stage 1 set input, triggers cancellation
// wait for stage 2 (blocks) triggering cancellation sends stage 2
// |
// (unblocked)
// dummy
// panics

#[test]
#[cfg_attr(miri, ignore)]
fn execute_cancellation() {
let mut db = Knobs::default();

let input = ParallelInput::new(&db, 10, 20);

let thread_a = std::thread::spawn({
let db = db.clone();
move || a1(&db, input)
});

db.signal_on_did_cancel(2);
input.set_a(&mut db).to(30);

// Assert thread A was cancelled
let cancelled = thread_a
.join()
.unwrap_err()
.downcast::<Cancelled>()
.unwrap();

// and inspect the output
expect_test::expect![[r#"
PendingWrite
"#]]
.assert_debug_eq(&cancelled);
}
2 changes: 1 addition & 1 deletion tests/parallel/parallel_map.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#![cfg(feature = "rayon")]
// test for rayon interactions.
// test for rayon-like parallel map interactions.

use salsa::Cancelled;
use salsa::Setter;
Expand Down
116 changes: 116 additions & 0 deletions tests/parallel/parallel_scope.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#![cfg(feature = "rayon")]
// test for rayon-like scope interactions.

use salsa::Cancelled;
use salsa::Setter;

use crate::setup::Knobs;
use crate::setup::KnobsDatabase;

#[salsa::input]
struct ParallelInput {
a: u32,
b: u32,
}

#[salsa::tracked]
fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> (u32, u32) {
let mut a = None;
let mut b = None;
salsa::scope(db, |scope| {
scope.spawn(|scope| a = Some(input.a(scope.db()) + 1));
scope.spawn(|scope| b = Some(input.b(scope.db()) + 1));
});
(a.unwrap(), b.unwrap())
}

#[salsa::tracked]
fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> (u32, u32) {
db.signal(1);
let mut a = None;
let mut b = None;
salsa::scope(db, |scope| {
scope.spawn(|scope| {
scope.db().wait_for(2);
a = Some(input.a(scope.db()) + 1)
});
scope.spawn(|scope| {
scope.db().wait_for(2);
b = Some(input.b(scope.db()) + 1)
});
});
(a.unwrap(), b.unwrap())
}

#[salsa::tracked]
fn dummy(_db: &dyn KnobsDatabase) -> u32 {
panic!("should never get here!")
}

#[test]
#[cfg_attr(miri, ignore)]
fn execute() {
let db = salsa::DatabaseImpl::new();

let input = ParallelInput::new(&db, 10, 20);

tracked_fn(&db, input);
}

// we expect this to panic, as `salsa::par_map` needs to be called from a query.
#[test]
#[cfg_attr(miri, ignore)]
#[should_panic]
fn direct_calls_panic() {
let db = salsa::DatabaseImpl::new();

let input = ParallelInput::new(&db, 10, 20);
salsa::scope(&db, |scope| {
scope.spawn(|scope| _ = input.a(scope.db()) + 1);
scope.spawn(|scope| _ = input.b(scope.db()) + 1);
});
}

// Cancellation signalling test
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1
// | wait for stage 1
// signal stage 1 set input, triggers cancellation
// wait for stage 2 (blocks) triggering cancellation sends stage 2
// |
// (unblocked)
// dummy
// panics

#[test]
#[cfg_attr(miri, ignore)]
fn execute_cancellation() {
let mut db = Knobs::default();

let input = ParallelInput::new(&db, 10, 20);

let thread_a = std::thread::spawn({
let db = db.clone();
move || a1(&db, input)
});

db.signal_on_did_cancel(2);
input.set_a(&mut db).to(30);

// Assert thread A was cancelled
let cancelled = thread_a
.join()
.unwrap_err()
.downcast::<Cancelled>()
.unwrap();

// and inspect the output
expect_test::expect![[r#"
PendingWrite
"#]]
.assert_debug_eq(&cancelled);
}
2 changes: 1 addition & 1 deletion tests/parallel/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::signal::Signal;
/// a certain behavior.
#[salsa::db]
pub(crate) trait KnobsDatabase: Database {
/// Signal that we are entering stage 1.
/// Signal that we are entering stage `stage`.
fn signal(&self, stage: usize);

/// Wait until we reach stage `stage` (no-op if we have already reached that stage).
Expand Down