diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ee9aa06b8..1defab25b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -91,7 +91,7 @@ jobs: - name: Test with Miri run: cargo miri nextest run --no-fail-fast --tests env: - MIRIFLAGS: -Zmiri-disable-isolation -Zmiri-retag-fields + MIRIFLAGS: -Zmiri-disable-isolation - name: Run examples with Miri run: cargo miri run --example calc diff --git a/src/attach.rs b/src/attach.rs index 973da8959..072b82238 100644 --- a/src/attach.rs +++ b/src/attach.rs @@ -40,6 +40,9 @@ impl Attached { Db: ?Sized + Database, { struct DbGuard<'s> { + /// The database that *we* attached on scope entry. + /// + /// `None` if one was already attached by a parent scope. state: Option<&'s Attached>, } @@ -47,6 +50,7 @@ impl Attached { #[inline] fn new(attached: &'s Attached, db: &dyn Database) -> Self { match attached.database.get() { + // A database is already attached, make sure it's the same as the new one. Some(current_db) => { let new_db = NonNull::from(db); if !std::ptr::addr_eq(current_db.as_ptr(), new_db.as_ptr()) { @@ -54,8 +58,8 @@ } Self { state: None } } + // No database is attached, attach the new one. None => { - // Otherwise, set the database. attached.database.set(Some(NonNull::from(db))); Self { state: Some(attached), @@ -70,7 +74,10 @@ impl Attached { fn drop(&mut self) { // Reset database to null if we did anything in `DbGuard::new`. if let Some(attached) = self.state { - attached.database.set(None); + if let Some(prev) = attached.database.replace(None) { + // SAFETY: `prev` is a valid pointer to a database. + unsafe { prev.as_ref().zalsa_local().uncancel() }; + } } } } @@ -85,17 +92,45 @@ impl Attached { Db: ?Sized + Database, { struct DbGuard<'s> { - state: &'s Attached, + /// The database that *we* attached on scope entry. + /// + /// `None` if one was already attached by a parent scope. + state: Option<&'s Attached>, + /// The previously attached database that we replaced, if any. + /// + /// We need to make sure to roll back and activate it again when we exit the scope. prev: Option<NonNull<dyn Database>>, } impl<'s> DbGuard<'s> { #[inline] fn new(attached: &'s Attached, db: &dyn Database) -> Self { - let prev = attached.database.replace(Some(NonNull::from(db))); - Self { - state: attached, - prev, + let db = NonNull::from(db); + match attached.database.replace(Some(db)) { + // A database was already attached by a parent scope. + Some(prev) => { + if std::ptr::eq(db.as_ptr(), prev.as_ptr()) { + // and it was the same as ours, so we did not change anything. + Self { + state: None, + prev: None, + } + } else { + // and it was a different one from ours, record the state changes. + Self { + state: Some(attached), + prev: Some(prev), + } + } + } + // No database is attached, attach the new one. + None => { + attached.database.set(Some(db)); + Self { + state: Some(attached), + prev: None, + } + } } } } @@ -103,7 +138,13 @@ impl Attached { impl Drop for DbGuard<'_> { #[inline] fn drop(&mut self) { - self.state.database.set(self.prev); + // Reset database to null if we did anything in `DbGuard::new`. + if let Some(attached) = self.state { + if let Some(prev) = attached.database.replace(self.prev) { + // SAFETY: `prev` is a valid pointer to a database.
+ unsafe { prev.as_ref().zalsa_local().uncancel() }; + } + } } } diff --git a/src/cancelled.rs b/src/cancelled.rs index 3c31bae5a..e690eac35 100644 --- a/src/cancelled.rs +++ b/src/cancelled.rs @@ -10,12 +10,13 @@ use std::panic::{self, UnwindSafe}; #[derive(Debug)] #[non_exhaustive] pub enum Cancelled { + /// The query was operating but the local database execution has been cancelled. + Local, + /// The query was operating on revision R, but there is a pending write to move to revision R+1. - #[non_exhaustive] PendingWrite, /// The query was blocked on another thread, and that thread panicked. - #[non_exhaustive] PropagatedPanic, } @@ -45,6 +46,7 @@ impl Cancelled { impl std::fmt::Display for Cancelled { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let why = match self { + Cancelled::Local => "local cancellation request", Cancelled::PendingWrite => "pending write", Cancelled::PropagatedPanic => "propagated panic", }; diff --git a/src/database.rs b/src/database.rs index 0df83b03b..0831fd5bf 100644 --- a/src/database.rs +++ b/src/database.rs @@ -3,6 +3,7 @@ use std::ptr::NonNull; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, ZalsaDatabase}; +use crate::zalsa_local::CancellationToken; use crate::{Durability, Revision}; #[derive(Copy, Clone)] @@ -59,7 +60,7 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase { zalsa_mut.runtime_mut().report_tracked_write(durability); } - /// This method triggers cancellation. + /// This method cancels all outstanding computations. /// If you invoke it while a snapshot exists, it /// will block until that snapshot is dropped -- if that snapshot /// is owned by the current thread, this could trigger deadlock. @@ -67,6 +68,11 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase { let _ = self.zalsa_mut(); } + /// Retrieves a [`CancellationToken`] for the current database handle. + fn cancellation_token(&self) -> CancellationToken { + self.zalsa_local().cancellation_token() + } + /// Reports that the query depends on some state unknown to salsa. /// /// Queries which report untracked reads will be re-executed in the next diff --git a/src/function/execute.rs b/src/function/execute.rs index a4dbe4986..11d8f0fdd 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -38,7 +38,6 @@ where &'db self, db: &'db C::DbView, mut claim_guard: ClaimGuard<'db>, - zalsa_local: &'db ZalsaLocal, opt_old_memo: Option<&Memo<'db, C>>, ) -> Option<&'db Memo<'db, C>> { let database_key_index = claim_guard.database_key_index(); @@ -60,7 +59,9 @@ where let (new_value, active_query) = Self::execute_query( db, zalsa, - zalsa_local.push_query(database_key_index, IterationCount::initial()), + claim_guard + .zalsa_local() + .push_query(database_key_index, IterationCount::initial()), opt_old_memo, ); (new_value, active_query.pop()) @@ -69,7 +70,9 @@ where let (mut new_value, active_query) = Self::execute_query( db, zalsa, - zalsa_local.push_query(database_key_index, IterationCount::initial()), + claim_guard + .zalsa_local() + .push_query(database_key_index, IterationCount::initial()), opt_old_memo, ); @@ -97,8 +100,9 @@ where // Cycle participants that don't have a fallback will be discarded in // `validate_provisional()`. 
let cycle_heads = std::mem::take(cycle_heads); - let active_query = - zalsa_local.push_query(database_key_index, IterationCount::initial()); + let active_query = claim_guard + .zalsa_local() + .push_query(database_key_index, IterationCount::initial()); new_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); completed_query = active_query.pop(); // We need to set `cycle_heads` and `verified_final` because it needs to propagate to the callers. @@ -109,13 +113,18 @@ where (new_value, completed_query) } - CycleRecoveryStrategy::Fixpoint => self.execute_maybe_iterate( - db, - opt_old_memo, - &mut claim_guard, - zalsa_local, - memo_ingredient_index, - ), + CycleRecoveryStrategy::Fixpoint => { + let zalsa_local = claim_guard.zalsa_local(); + let was_disabled = zalsa_local.set_cancellation_disabled(true); + let res = self.execute_maybe_iterate( + db, + opt_old_memo, + &mut claim_guard, + memo_ingredient_index, + ); + zalsa_local.set_cancellation_disabled(was_disabled); + res + } }; if let Some(old_memo) = opt_old_memo { @@ -158,7 +167,6 @@ where db: &'db C::DbView, opt_old_memo: Option<&Memo<'db, C>>, claim_guard: &mut ClaimGuard<'db>, - zalsa_local: &'db ZalsaLocal, memo_ingredient_index: MemoIngredientIndex, ) -> (C::Output<'db>, CompletedQuery) { claim_guard.set_release_mode(ReleaseMode::Default); @@ -205,7 +213,9 @@ where PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index); let (new_value, completed_query) = loop { - let active_query = zalsa_local.push_query(database_key_index, iteration_count); + let active_query = claim_guard + .zalsa_local() + .push_query(database_key_index, iteration_count); // Tracked struct ids that existed in the previous revision // but weren't recreated in the last iteration. It's important that we seed the next @@ -248,7 +258,12 @@ where iteration_count, ); - let outer_cycle = outer_cycle(zalsa, zalsa_local, &cycle_heads, database_key_index); + let outer_cycle = outer_cycle( + zalsa, + claim_guard.zalsa_local(), + &cycle_heads, + database_key_index, + ); // Did the new result we got depend on our own provisional value, in a cycle? // If not, return because this query is not a cycle head. diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 588b08bb1..48eee089f 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -105,10 +105,15 @@ where ) -> Option<&'db Memo<'db, C>> { let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. 
- let claim_guard = match self.sync_table.try_claim(zalsa, id, Reentrancy::Allow) { + let claim_guard = match self + .sync_table + .try_claim(zalsa, zalsa_local, id, Reentrancy::Allow) + { ClaimResult::Claimed(guard) => guard, ClaimResult::Running(blocked_on) => { - blocked_on.block_on(zalsa); + if !blocked_on.block_on(zalsa) { + return None; + } if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); @@ -176,7 +181,7 @@ where } } - self.execute(db, claim_guard, zalsa_local, opt_old_memo) + self.execute(db, claim_guard, opt_old_memo) } #[cold] diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 6ea17b13f..f25d955e1 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -141,23 +141,24 @@ where ) -> Option { let database_key_index = self.database_key_index(key_index); - let claim_guard = match self - .sync_table - .try_claim(zalsa, key_index, Reentrancy::Deny) - { - ClaimResult::Claimed(guard) => guard, - ClaimResult::Running(blocked_on) => { - blocked_on.block_on(zalsa); - return None; - } - ClaimResult::Cycle { .. } => { - return Some(self.maybe_changed_after_cold_cycle( - zalsa_local, - database_key_index, - cycle_heads, - )) - } - }; + let claim_guard = + match self + .sync_table + .try_claim(zalsa, zalsa_local, key_index, Reentrancy::Deny) + { + ClaimResult::Claimed(guard) => guard, + ClaimResult::Running(blocked_on) => { + _ = blocked_on.block_on(zalsa); + return None; + } + ClaimResult::Cycle { .. } => { + return Some(self.maybe_changed_after_cold_cycle( + zalsa_local, + database_key_index, + cycle_heads, + )) + } + }; // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index) else { @@ -228,7 +229,7 @@ where // `in_cycle` tracks if the enclosing query is in a cycle. `deep_verify.cycle_heads` tracks // if **this query** encountered a cycle (which means there's some provisional value somewhere floating around). if old_memo.value.is_some() && !cycle_heads.has_any() { - let memo = self.execute(db, claim_guard, zalsa_local, Some(old_memo))?; + let memo = self.execute(db, claim_guard, Some(old_memo))?; let changed_at = memo.revisions.changed_at; // Always assume that a provisional value has changed. diff --git a/src/function/memo.rs b/src/function/memo.rs index 234829cb1..9ebae6c9e 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -180,7 +180,14 @@ impl<'db, C: Configuration> Memo<'db, C> { } TryClaimHeadsResult::Running(running) => { all_cycles = false; - running.block_on(zalsa); + if !running.block_on(zalsa) { + // We cannot really handle local cancellations reliably here + // so we treat it as a general cancellation / panic. + // + // We shouldn't hit this though as we disable local cancellation + // in cycles. + crate::Cancelled::PropagatedPanic.throw(); + } } } } @@ -501,7 +508,7 @@ mod _memory_usage { use std::any::TypeId; use std::num::NonZeroUsize; - // Memo's are stored a lot, make sure their size is doesn't randomly increase. + // Memo's are stored a lot, make sure their size doesn't randomly increase. 
const _: [(); std::mem::size_of::>()] = [(); std::mem::size_of::<[usize; 6]>()]; diff --git a/src/function/sync.rs b/src/function/sync.rs index c9a74a307..f01932a37 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -2,6 +2,7 @@ use rustc_hash::FxHashMap; use std::collections::hash_map::OccupiedEntry; use crate::key::DatabaseKeyIndex; +use crate::plumbing::ZalsaLocal; use crate::runtime::{ BlockOnTransferredOwner, BlockResult, BlockTransferredResult, Running, WaitResult, }; @@ -21,6 +22,8 @@ pub(crate) struct SyncTable { } pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> { + /// Successfully claimed the query. + Claimed(Guard), /// Can't claim the query because it is running on an other thread. Running(Running<'a>), /// Claiming the query results in a cycle. @@ -30,8 +33,6 @@ pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> { /// [`SyncTable::try_claim`] with [`Reentrant::Allow`]. inner: bool, }, - /// Successfully claimed the query. - Claimed(Guard), } pub(crate) struct SyncState { @@ -68,6 +69,7 @@ impl SyncTable { pub(crate) fn try_claim<'me>( &'me self, zalsa: &'me Zalsa, + zalsa_local: &'me ZalsaLocal, key_index: Id, reentrant: Reentrancy, ) -> ClaimResult<'me> { @@ -77,7 +79,12 @@ impl SyncTable { let id = match occupied_entry.get().id { SyncOwner::Thread(id) => id, SyncOwner::Transferred => { - return match self.try_claim_transferred(zalsa, occupied_entry, reentrant) { + return match self.try_claim_transferred( + zalsa, + zalsa_local, + occupied_entry, + reentrant, + ) { Ok(claimed) => claimed, Err(other_thread) => match other_thread.block(write) { BlockResult::Cycle => ClaimResult::Cycle { inner: false }, @@ -115,6 +122,7 @@ impl SyncTable { ClaimResult::Claimed(ClaimGuard { key_index, zalsa, + zalsa_local, sync_table: self, mode: ReleaseMode::Default, }) @@ -172,6 +180,7 @@ impl SyncTable { fn try_claim_transferred<'me>( &'me self, zalsa: &'me Zalsa, + zalsa_local: &'me ZalsaLocal, mut entry: OccupiedEntry, reentrant: Reentrancy, ) -> Result, Box>> { @@ -195,6 +204,7 @@ impl SyncTable { Ok(ClaimResult::Claimed(ClaimGuard { key_index, zalsa, + zalsa_local, sync_table: self, mode: ReleaseMode::SelfOnly, })) @@ -214,6 +224,7 @@ impl SyncTable { Ok(ClaimResult::Claimed(ClaimGuard { key_index, zalsa, + zalsa_local, sync_table: self, mode: ReleaseMode::Default, })) @@ -277,7 +288,7 @@ pub enum SyncOwner { /// The query's lock ownership has been transferred to another query. /// E.g. if `a` transfers its ownership to `b`, then only the thread in the critical path - /// to complete b` can claim `a` (in most instances, only the thread owning `b` can claim `a`). + /// to complete `b` can claim `a` (in most instances, only the thread owning `b` can claim `a`). /// /// The thread owning `a` is stored in the `DependencyGraph`. 
/// @@ -295,6 +306,7 @@ pub(crate) struct ClaimGuard<'me> { zalsa: &'me Zalsa, sync_table: &'me SyncTable, mode: ReleaseMode, + zalsa_local: &'me ZalsaLocal, } impl<'me> ClaimGuard<'me> { @@ -302,6 +314,10 @@ impl<'me> ClaimGuard<'me> { self.zalsa } + pub(crate) fn zalsa_local(&self) -> &'me ZalsaLocal { + self.zalsa_local + } + pub(crate) const fn database_key_index(&self) -> DatabaseKeyIndex { DatabaseKeyIndex::new(self.sync_table.ingredient, self.key_index) } @@ -315,12 +331,17 @@ impl<'me> ClaimGuard<'me> { fn release_panicking(&self) { let mut syncs = self.sync_table.syncs.lock(); let state = syncs.remove(&self.key_index).expect("key claimed twice?"); + let result = if self.zalsa_local.should_trigger_local_cancellation() { + WaitResult::Cancelled + } else { + WaitResult::Panicked + }; tracing::debug!( - "Release claim on {:?} due to panic", - self.database_key_index() + "Release claim on {:?} due to {:?}", + self.database_key_index(), + result ); - - self.release(state, WaitResult::Panicked); + self.release(state, result); } #[inline(always)] diff --git a/src/interned.rs b/src/interned.rs index 544a8d0ee..afd15f71e 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -850,6 +850,7 @@ pub struct StructEntry<'db, C> where C: Configuration, { + #[allow(dead_code)] value: &'db Value, key: DatabaseKeyIndex, } diff --git a/src/lib.rs b/src/lib.rs index f90fce338..d2324156f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,6 +63,7 @@ pub use self::runtime::Runtime; pub use self::storage::{Storage, StorageHandle}; pub use self::update::Update; pub use self::zalsa::IngredientIndex; +pub use self::zalsa_local::CancellationToken; pub use crate::attach::{attach, attach_allow_change, with_attached_database}; pub mod prelude { diff --git a/src/runtime.rs b/src/runtime.rs index 48caf53ec..5b36bf205 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -13,11 +13,11 @@ mod dependency_graph; #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct Runtime { - /// Set to true when the current revision has been canceled. + /// Set to true when the current revision has been cancelled. /// This is done when we an input is being changed. The flag /// is set back to false once the input has been changed. #[cfg_attr(feature = "persistence", serde(skip))] - revision_canceled: AtomicBool, + revision_cancelled: AtomicBool, /// Stores the "last change" revision for values of each duration. /// This vector is always of length at least 1 (for Durability 0) @@ -44,6 +44,7 @@ pub struct Runtime { pub(super) enum WaitResult { Completed, Panicked, + Cancelled, } #[derive(Debug)] @@ -121,7 +122,14 @@ struct BlockedOnInner<'me> { impl Running<'_> { /// Blocks on the other thread to complete the computation. - pub(crate) fn block_on(self, zalsa: &Zalsa) { + /// + /// Returns `true` if the computation was successful, and `false` if the other thread was locally cancelled. + /// + /// # Panics + /// + /// If the other thread panics, this function will panic as well. + #[must_use] + pub(crate) fn block_on(self, zalsa: &Zalsa) -> bool { let BlockedOnInner { dg, query_mutex_guard, @@ -151,7 +159,8 @@ impl Running<'_> { // by the other thread and responded to appropriately. 
Cancelled::PropagatedPanic.throw() } - WaitResult::Completed => {} + WaitResult::Cancelled => false, + WaitResult::Completed => true, } } } @@ -183,7 +192,7 @@ impl Default for Runtime { fn default() -> Self { Runtime { revisions: [Revision::start(); Durability::LEN], - revision_canceled: Default::default(), + revision_cancelled: Default::default(), dependency_graph: Default::default(), table: Default::default(), } @@ -194,7 +203,7 @@ impl std::fmt::Debug for Runtime { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fmt.debug_struct("Runtime") .field("revisions", &self.revisions) - .field("revision_canceled", &self.revision_canceled) + .field("revision_cancelled", &self.revision_cancelled) .field("dependency_graph", &self.dependency_graph) .finish() } @@ -227,16 +236,16 @@ impl Runtime { } pub(crate) fn load_cancellation_flag(&self) -> bool { - self.revision_canceled.load(Ordering::Acquire) + self.revision_cancelled.load(Ordering::Acquire) } pub(crate) fn set_cancellation_flag(&self) { crate::tracing::trace!("set_cancellation_flag"); - self.revision_canceled.store(true, Ordering::Release); + self.revision_cancelled.store(true, Ordering::Release); } pub(crate) fn reset_cancellation_flag(&mut self) { - *self.revision_canceled.get_mut() = false; + *self.revision_cancelled.get_mut() = false; } /// Returns the [`Table`] used to store the value of salsa structs diff --git a/src/storage.rs b/src/storage.rs index b6c532709..c2a7029ee 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -136,12 +136,15 @@ impl Storage { .record_unfilled_pages(self.handle.zalsa_impl.table()); let Self { handle, - zalsa_local: _, - } = &self; + zalsa_local, + } = &mut self; // Avoid rust's annoying destructure prevention rules for `Drop` types // SAFETY: We forget `Self` afterwards to discard the original copy, and the destructure // above makes sure we won't forget to take into account newly added fields. let handle = unsafe { std::ptr::read(handle) }; + // SAFETY: We forget `Self` afterwards to discard the original copy, and the destructure + // above makes sure we won't forget to take into account newly added fields. + unsafe { std::ptr::drop_in_place(zalsa_local) }; std::mem::forget::(self); handle } diff --git a/src/zalsa.rs b/src/zalsa.rs index f2f520b49..d550b4c1f 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -296,8 +296,11 @@ impl Zalsa { #[inline] pub(crate) fn unwind_if_revision_cancelled(&self, zalsa_local: &ZalsaLocal) { self.event(&|| crate::Event::new(crate::EventKind::WillCheckCancellation)); + if zalsa_local.should_trigger_local_cancellation() { + zalsa_local.unwind_cancelled(); + } if self.runtime().load_cancellation_flag() { - zalsa_local.unwind_cancelled(self.current_revision()); + zalsa_local.unwind_pending_write(); } } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 8f0239e56..d60582eab 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -3,6 +3,8 @@ use std::fmt; use std::fmt::Formatter; use std::panic::UnwindSafe; use std::ptr::{self, NonNull}; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::Arc; use rustc_hash::FxHashMap; use thin_vec::ThinVec; @@ -39,6 +41,45 @@ pub struct ZalsaLocal { /// Stores the most recent page for a given ingredient. /// This is thread-local to avoid contention. most_recent_pages: UnsafeCell>, + + cancelled: CancellationToken, +} + +/// A cancellation token that can be used to cancel a query computation for a specific local `Database`. 
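+///
+/// Tokens are cheap to clone (internally an `Arc`) and are obtained from
+/// `Database::cancellation_token`. Requesting cancellation makes the next cancellation check
+/// performed by a query on that handle unwind with `Cancelled::Local` (unless cancellation is
+/// temporarily disabled, as it is during fixpoint cycle iteration).
+///
+/// A minimal usage sketch (not a doctest; `MyDatabase` and `my_query` are hypothetical
+/// placeholders, the rest is the API introduced here):
+///
+/// ```ignore
+/// let db = MyDatabase::default();
+/// let token = db.cancellation_token();
+///
+/// // From another thread: ask this handle to stop whatever it is computing.
+/// token.cancel();
+///
+/// // On the query thread, the unwind surfaces as `Cancelled::Local`.
+/// let result = salsa::Cancelled::catch(|| my_query(&db));
+/// assert!(matches!(result, Err(salsa::Cancelled::Local)));
+/// ```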
+#[derive(Default, Clone, Debug)] +pub struct CancellationToken(Arc<AtomicU8>); + +impl CancellationToken { + const CANCELLED_MASK: u8 = 0b01; + const DISABLED_MASK: u8 = 0b10; + + /// Inform the database to cancel the current query computation. + pub fn cancel(&self) { + self.0.fetch_or(Self::CANCELLED_MASK, Ordering::Relaxed); + } + + /// Check if the query computation has been requested to be cancelled. + pub fn is_cancelled(&self) -> bool { + self.0.load(Ordering::Relaxed) & Self::CANCELLED_MASK != 0 + } + + #[inline] + fn set_cancellation_disabled(&self, disabled: bool) -> bool { + let previous_disabled_bit = if disabled { + self.0.fetch_or(Self::DISABLED_MASK, Ordering::Relaxed) + } else { + self.0.fetch_and(!Self::DISABLED_MASK, Ordering::Relaxed) + }; + previous_disabled_bit & Self::DISABLED_MASK != 0 + } + + /// `true` only when the cancelled bit is set and the disabled bit is clear. + fn should_trigger_local_cancellation(&self) -> bool { + self.0.load(Ordering::Relaxed) == Self::CANCELLED_MASK + } + + fn reset(&self) { + self.0.store(0, Ordering::Relaxed); + } } impl ZalsaLocal { @@ -46,6 +87,7 @@ impl ZalsaLocal { ZalsaLocal { query_stack: RefCell::new(QueryStack::default()), most_recent_pages: UnsafeCell::new(FxHashMap::default()), + cancelled: CancellationToken::default(), } } @@ -401,12 +443,35 @@ impl ZalsaLocal { } } + #[inline] + pub(crate) fn cancellation_token(&self) -> CancellationToken { + self.cancelled.clone() + } + + #[inline] + pub(crate) fn uncancel(&self) { + self.cancelled.reset(); + } + + #[inline] + pub fn should_trigger_local_cancellation(&self) -> bool { + self.cancelled.should_trigger_local_cancellation() + } + #[cold] - pub(crate) fn unwind_cancelled(&self, current_revision: Revision) { - // Why is this reporting an untracked read? We do not store the query revisions on unwind do we? - self.report_untracked_read(current_revision); + pub(crate) fn unwind_pending_write(&self) { Cancelled::PendingWrite.throw(); } + + #[cold] + pub(crate) fn unwind_cancelled(&self) { + Cancelled::Local.throw(); + } + + #[inline] + pub(crate) fn set_cancellation_disabled(&self, was_disabled: bool) -> bool { + self.cancelled.set_cancellation_disabled(was_disabled) + } } // Okay to implement as `ZalsaLocal`` is !Sync diff --git a/tests/cancellation_token.rs b/tests/cancellation_token.rs new file mode 100644 index 000000000..f6a14930a --- /dev/null +++ b/tests/cancellation_token.rs @@ -0,0 +1,67 @@ +#![cfg(feature = "inventory")] +//! Test that cancelling via a `CancellationToken` unwinds the in-flight query with `Cancelled::Local`, and that the database handle can run queries again afterwards.
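+//!
+//! Rough flow: the spawned thread calls `token.cancel()` between the two barriers, so the first
+//! call to `a` unwinds with `Cancelled::Local` before `b` is ever executed. Dropping the attach
+//! scope calls `uncancel()` on the handle, which is why the second call to `a` (and `b`) then
+//! runs to completion.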
+ +mod common; + +use std::{sync::Barrier, thread}; + +use expect_test::expect; +use salsa::{Cancelled, Database}; + +use crate::common::LogDatabase; + +#[salsa::input(debug)] +struct MyInput { + field: u32, +} + +#[salsa::tracked] +fn a(db: &dyn Database, input: MyInput) -> u32 { + BARRIER.wait(); + BARRIER2.wait(); + b(db, input) +} +#[salsa::tracked] +fn b(db: &dyn Database, input: MyInput) -> u32 { + input.field(db) +} + +static BARRIER: Barrier = Barrier::new(2); +static BARRIER2: Barrier = Barrier::new(2); + +#[test] +fn cancellation_token() { + let db = common::EventLoggerDatabase::default(); + let token = db.cancellation_token(); + let input = MyInput::new(&db, 22); + let res = Cancelled::catch(|| { + thread::scope(|s| { + s.spawn(|| { + BARRIER.wait(); + token.cancel(); + BARRIER2.wait(); + }); + a(&db, input) + }) + }); + assert!(matches!(res, Err(Cancelled::Local)), "{res:?}"); + drop(res); + db.assert_logs(expect![[r#" + [ + "WillCheckCancellation", + "WillExecute { database_key: a(Id(0)) }", + "WillCheckCancellation", + ]"#]]); + thread::spawn(|| { + BARRIER.wait(); + BARRIER2.wait(); + }); + a(&db, input); + db.assert_logs(expect![[r#" + [ + "WillCheckCancellation", + "WillExecute { database_key: a(Id(0)) }", + "WillCheckCancellation", + "WillExecute { database_key: b(Id(0)) }", + ]"#]]); +} diff --git a/tests/interned-revisions.rs b/tests/interned-revisions.rs index bef1db61c..02d7d8112 100644 --- a/tests/interned-revisions.rs +++ b/tests/interned-revisions.rs @@ -155,8 +155,8 @@ fn test_immortal() { // Modify the input to bump the revision and intern a new value. // - // No values should ever be reused with `durability = usize::MAX`. - for i in 1..100 { + // No values should ever be reused with `revisions = usize::MAX`. + for i in 1..if cfg!(miri) { 50 } else { 1000 } { input.set_field1(&mut db).to(i); let result = function(&db, input); assert_eq!(result.field1(&db).0, i); diff --git a/tests/parallel/cancellation_token_cycle_nested.rs b/tests/parallel/cancellation_token_cycle_nested.rs new file mode 100644 index 000000000..b1e8fbb62 --- /dev/null +++ b/tests/parallel/cancellation_token_cycle_nested.rs @@ -0,0 +1,146 @@ +// Shuttle doesn't like panics inside of its runtime. +#![cfg(not(feature = "shuttle"))] + +//! Test for cancellation with deeply nested cycles across multiple threads. +//! +//! These tests verify that local cancellation is disabled during cycle iteration, +//! allowing multi-threaded cycles to complete successfully before cancellation +//! can take effect. 
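+//!
+//! Query graph used below (every arrow is a call, so `a`..`e` form one interconnected cycle,
+//! while `f` merely blocks on it and then calls the uncomputed `h`):
+//!
+//! ```text
+//! a -> b -> c -> { d, e, b, a }
+//! d -> c
+//! e -> c
+//! f -> c, then f -> h
+//! ```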
+use salsa::Database; + +use crate::setup::{Knobs, KnobsDatabase}; +use crate::sync::thread; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(3); + +#[salsa::tracked(cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + query_b(db) +} + +#[salsa::tracked(cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + let c_value = query_c(db); + CycleValue(c_value.0 + 1).min(MAX) +} + +#[salsa::tracked(cycle_initial=initial)] +fn query_c(db: &dyn KnobsDatabase) -> CycleValue { + let d_value = query_d(db); + let e_value = query_e(db); + let b_value = query_b(db); + let a_value = query_a(db); + CycleValue(d_value.0.max(e_value.0).max(b_value.0).max(a_value.0)) +} + +#[salsa::tracked(cycle_initial=initial)] +fn query_d(db: &dyn KnobsDatabase) -> CycleValue { + query_c(db) +} + +#[salsa::tracked(cycle_initial=initial)] +fn query_e(db: &dyn KnobsDatabase) -> CycleValue { + query_c(db) +} + +#[salsa::tracked] +fn query_f(db: &dyn KnobsDatabase) -> CycleValue { + let c = query_c(db); + // this should trigger cancellation again + query_h(db); + c +} + +#[salsa::tracked] +fn query_h(db: &dyn KnobsDatabase) { + _ = db; +} + +fn initial(db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { + db.signal(1); + db.wait_for(6); + MIN +} + +/// Test that a multi-threaded cycle completes successfully even when +/// cancellation is requested during the cycle. +/// +/// This test is similar to cycle_nested_deep but adds cancellation during +/// the cycle to verify that cancellation is properly deferred. +#[test] +fn multi_threaded_cycle_completes_despite_cancellation() { + let db = Knobs::default(); + let db_t1 = db.clone(); + let db_t2 = db.clone(); + let db_t3 = db.clone(); + let db_t4 = db.clone(); + let db_t5 = db.clone(); + let db_signaler = db; + + let token_t1 = db_t1.cancellation_token(); + let token_t2 = db_t2.cancellation_token(); + let token_t3 = db_t3.cancellation_token(); + let token_t5 = db_t5.cancellation_token(); + + // Thread 1: Runs the main cycle, will have cancellation requested during it + let t1 = thread::spawn(move || query_a(&db_t1)); + + // Wait for t1 to start the cycle + db_signaler.wait_for(1); + + // Spawn t2 and wait for it to block on the cycle + db_signaler.signal_on_will_block(2); + let t2 = thread::spawn(move || query_b(&db_t2)); + db_signaler.wait_for(2); + + // Spawn t3 and wait for it to block on the cycle + db_signaler.signal_on_will_block(3); + let t3 = thread::spawn(move || query_d(&db_t3)); + db_signaler.wait_for(3); + + // Spawn t4 - doesn't get cancelled + db_signaler.signal_on_will_block(4); + let t4 = thread::spawn(move || query_e(&db_t4)); + db_signaler.wait_for(4); + + // Spawn t5 - doesn't get cancelled + db_signaler.signal_on_will_block(5); + let t5 = thread::spawn(move || query_f(&db_t5)); + db_signaler.wait_for(5); + + // Request cancellation while t2 and t3 are blocked on the cycle + // This should be deferred until after the cycle completes + token_t1.cancel(); + token_t2.cancel(); + token_t3.cancel(); + token_t5.cancel(); + + // Let t1 continue - the cycle should still complete because + // cancellation is disabled during fixpoint iteration + db_signaler.signal(6); + + // All threads should complete successfully + let r_t1 = t1.join().unwrap(); + let r_t2 = t2.join().unwrap(); + let r_t3 = t3.join().unwrap(); + let r_t4 = t4.join().unwrap(); + + let r_t5 = t5.join().unwrap_err(); 
+ + // All should get MAX because cycles defer cancellation + assert_eq!(r_t1, MAX, "t1 should get MAX"); + assert_eq!(r_t2, MAX, "t2 should get MAX"); + assert_eq!(r_t3, MAX, "t3 should get MAX"); + assert_eq!(r_t4, MAX, "t4 should get MAX"); + assert!( + matches!( + *r_t5.downcast::<salsa::Cancelled>().unwrap(), + salsa::Cancelled::Local + ), + "t5 should be cancelled as it's blocked on the cycle, not participating in it, and calls an uncomputed query afterwards" + ); +} diff --git a/tests/parallel/cancellation_token_multi_blocked.rs b/tests/parallel/cancellation_token_multi_blocked.rs new file mode 100644 index 000000000..3e3caf231 --- /dev/null +++ b/tests/parallel/cancellation_token_multi_blocked.rs @@ -0,0 +1,86 @@ +// Shuttle doesn't like panics inside of its runtime. +#![cfg(not(feature = "shuttle"))] + +//! Test for cancellation when multiple queries are blocked on the cancelled thread. +//! +//! This test verifies that: +//! 1. When a thread is cancelled, blocked threads recompute rather than propagate cancellation +//! 2. The final result is correctly computed by the remaining threads +use salsa::{Cancelled, Database}; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[salsa::tracked] +fn query_a(db: &dyn KnobsDatabase) -> u32 { + query_b(db) +} + +#[salsa::tracked] +fn query_b(db: &dyn KnobsDatabase) -> u32 { + // Signal that t1 has started computing query_b + db.signal(1); + // Wait for t2 and t3 to block on us + db.wait_for(3); + // Wait for cancellation to happen + db.wait_for(4); + query_c(db) +} + +#[salsa::tracked] +fn query_c(_db: &dyn KnobsDatabase) -> u32 { + 42 +} + +/// Test that when a thread is cancelled, other blocked threads successfully +/// recompute the query and get the correct result. +#[test] +fn multiple_threads_blocked_on_cancelled() { + let db = Knobs::default(); + let db2 = db.clone(); + let db3 = db.clone(); + let db_signaler = db.clone(); + let token = db.cancellation_token(); + + // Thread 1: Starts computing query_a -> query_b, will be cancelled + let t1 = std::thread::spawn(move || query_a(&db)); + + // Wait for t1 to start query_b + db_signaler.wait_for(1); + + // Thread 2: Will block on query_a (which is blocked on query_b) + db2.signal_on_will_block(2); + let t2 = std::thread::spawn(move || query_a(&db2)); + + // Wait for t2 to block + db_signaler.wait_for(2); + + // Thread 3: Also blocks on query_a + db3.signal_on_will_block(3); + let t3 = std::thread::spawn(move || query_a(&db3)); + + // Wait for t3 to block + db_signaler.wait_for(3); + + // Now cancel t1 + token.cancel(); + + // Let t1 continue and get cancelled + db_signaler.signal(4); + + // Collect results + let r1 = t1.join(); + let r2 = t2.join(); + let r3 = t3.join(); + + // t1 should have been cancelled + let r1_cancelled = r1.unwrap_err().downcast::<Cancelled>().map(|c| *c); + assert!( + matches!(r1_cancelled, Ok(Cancelled::Local)), + "t1 should be locally cancelled, got: {:?}", + r1_cancelled + ); + + // t2 and t3 should both succeed with the correct value + assert_eq!(r2.unwrap(), 42, "t2 should compute the correct result"); + assert_eq!(r3.unwrap(), 42, "t3 should compute the correct result"); +} diff --git a/tests/parallel/cancellation_token_recomputes.rs b/tests/parallel/cancellation_token_recomputes.rs new file mode 100644 index 000000000..0bbb67ef0 --- /dev/null +++ b/tests/parallel/cancellation_token_recomputes.rs @@ -0,0 +1,43 @@ +// Shuttle doesn't like panics inside of its runtime. +#![cfg(not(feature = "shuttle"))] + +//! Test for cancellation when another query is blocked on the cancelled thread.
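+//!
+//! The cancelled thread releases its claim with `WaitResult::Cancelled`, so the blocked thread's
+//! `block_on` returns `false` and it retries, recomputing the query itself instead of propagating
+//! the cancellation and ending up with the correct value.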
+use salsa::{Cancelled, Database}; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[salsa::tracked] +fn query_a(db: &dyn KnobsDatabase) -> u32 { + query_b(db) +} + +#[salsa::tracked] +fn query_b(db: &dyn KnobsDatabase) -> u32 { + db.signal(1); + db.wait_for(3); + query_c(db) +} + +#[salsa::tracked] +fn query_c(_db: &dyn KnobsDatabase) -> u32 { + 1 +} +#[test] +fn execute() { + let db = Knobs::default(); + let db2 = db.clone(); + let db_signaler = db.clone(); + let token = db.cancellation_token(); + + let t1 = std::thread::spawn(move || query_a(&db)); + db_signaler.wait_for(1); + db2.signal_on_will_block(2); + let t2 = std::thread::spawn(move || query_a(&db2)); + db_signaler.wait_for(2); + token.cancel(); + db_signaler.signal(3); + let (r1, r2) = (t1.join(), t2.join()); + let r1 = *r1.unwrap_err().downcast::<Cancelled>().unwrap(); + assert!(matches!(r1, Cancelled::Local), "{r1:?}"); + assert_eq!(r2.unwrap(), 1); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 1062d4899..399eaa7da 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -3,6 +3,9 @@ mod setup; mod signal; +mod cancellation_token_cycle_nested; +mod cancellation_token_multi_blocked; +mod cancellation_token_recomputes; mod cycle_a_t1_b_t2; mod cycle_a_t1_b_t2_fallback; mod cycle_ab_peeping_c;