From 5be84e608fa27e00521cb36fdee250d46607fb70 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 17 Oct 2025 19:01:29 +0200 Subject: [PATCH 1/7] Introduce a `CancellationToken` for cancelling specific computations --- src/attach.rs | 42 ++++++++++-- src/cancelled.rs | 6 +- src/database.rs | 8 ++- src/function/fetch.rs | 49 ++++++++------ src/function/maybe_changed_after.rs | 33 ++++----- src/function/memo.rs | 5 +- src/function/sync.rs | 37 ++++++++-- src/interned.rs | 1 + src/lib.rs | 1 + src/runtime.rs | 24 ++++--- src/storage.rs | 7 +- src/zalsa.rs | 5 +- src/zalsa_local.rs | 49 +++++++++++++- tests/cancellation_token.rs | 67 +++++++++++++++++++ tests/interned-revisions.rs | 2 +- .../parallel/cancellation_token_recomputes.rs | 43 ++++++++++++ tests/parallel/main.rs | 1 + 17 files changed, 312 insertions(+), 68 deletions(-) create mode 100644 tests/cancellation_token.rs create mode 100644 tests/parallel/cancellation_token_recomputes.rs diff --git a/src/attach.rs b/src/attach.rs index 973da8959..3a6edc93f 100644 --- a/src/attach.rs +++ b/src/attach.rs @@ -70,7 +70,10 @@ impl Attached { fn drop(&mut self) { // Reset database to null if we did anything in `DbGuard::new`. if let Some(attached) = self.state { - attached.database.set(None); + if let Some(prev) = attached.database.replace(None) { + // SAFETY: `prev` is a valid pointer to a database. + unsafe { prev.as_ref().zalsa_local().uncancel() }; + } } } } @@ -85,17 +88,36 @@ impl Attached { Db: ?Sized + Database, { struct DbGuard<'s> { - state: &'s Attached, + state: Option<&'s Attached>, prev: Option>, } impl<'s> DbGuard<'s> { #[inline] fn new(attached: &'s Attached, db: &dyn Database) -> Self { - let prev = attached.database.replace(Some(NonNull::from(db))); - Self { - state: attached, - prev, + let db = NonNull::from(db); + match attached.database.replace(Some(db)) { + Some(prev) => { + if std::ptr::eq(db.as_ptr(), prev.as_ptr()) { + Self { + state: None, + prev: None, + } + } else { + Self { + state: Some(attached), + prev: Some(prev), + } + } + } + None => { + // Otherwise, set the database. + attached.database.set(Some(db)); + Self { + state: Some(attached), + prev: None, + } + } } } } @@ -103,7 +125,13 @@ impl Attached { impl Drop for DbGuard<'_> { #[inline] fn drop(&mut self) { - self.state.database.set(self.prev); + // Reset database to null if we did anything in `DbGuard::new`. + if let Some(attached) = self.state { + if let Some(prev) = attached.database.replace(self.prev) { + // SAFETY: `prev` is a valid pointer to a database. + unsafe { prev.as_ref().zalsa_local().uncancel() }; + } + } } } diff --git a/src/cancelled.rs b/src/cancelled.rs index 3c31bae5a..1fa0edc59 100644 --- a/src/cancelled.rs +++ b/src/cancelled.rs @@ -10,12 +10,13 @@ use std::panic::{self, UnwindSafe}; #[derive(Debug)] #[non_exhaustive] pub enum Cancelled { + /// The query was operating but the local database execution has been cancelled. + Cancelled, + /// The query was operating on revision R, but there is a pending write to move to revision R+1. - #[non_exhaustive] PendingWrite, /// The query was blocked on another thread, and that thread panicked. 
- #[non_exhaustive] PropagatedPanic, } @@ -45,6 +46,7 @@ impl Cancelled { impl std::fmt::Display for Cancelled { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let why = match self { + Cancelled::Cancelled => "canellation request", Cancelled::PendingWrite => "pending write", Cancelled::PropagatedPanic => "propagated panic", }; diff --git a/src/database.rs b/src/database.rs index 0df83b03b..9cb70f917 100644 --- a/src/database.rs +++ b/src/database.rs @@ -3,6 +3,7 @@ use std::ptr::NonNull; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, ZalsaDatabase}; +use crate::zalsa_local::CancellationToken; use crate::{Durability, Revision}; #[derive(Copy, Clone)] @@ -59,7 +60,7 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase { zalsa_mut.runtime_mut().report_tracked_write(durability); } - /// This method triggers cancellation. + /// This method cancels all outstanding computations. /// If you invoke it while a snapshot exists, it /// will block until that snapshot is dropped -- if that snapshot /// is owned by the current thread, this could trigger deadlock. @@ -67,6 +68,11 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase { let _ = self.zalsa_mut(); } + /// Retrives a [`CancellationToken`] for the current database. + fn cancellation_token(&self) -> CancellationToken { + self.zalsa_local().cancellation_token() + } + /// Reports that the query depends on some state unknown to salsa. /// /// Queries which report untracked reads will be re-executed in the next diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 588b08bb1..313bd43ee 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -105,32 +105,39 @@ where ) -> Option<&'db Memo<'db, C>> { let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. - let claim_guard = match self.sync_table.try_claim(zalsa, id, Reentrancy::Allow) { - ClaimResult::Claimed(guard) => guard, - ClaimResult::Running(blocked_on) => { - blocked_on.block_on(zalsa); + let claim_guard = loop { + match self + .sync_table + .try_claim(zalsa, zalsa_local, id, Reentrancy::Allow) + { + ClaimResult::Claimed(guard) => break guard, + ClaimResult::Running(blocked_on) => { + if !blocked_on.block_on(zalsa) { + continue; + } - if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { - let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { + let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); - if let Some(memo) = memo { - if memo.value.is_some() { - memo.block_on_heads(zalsa); + if let Some(memo) = memo { + if memo.value.is_some() { + memo.block_on_heads(zalsa); + } } } - } - return None; - } - ClaimResult::Cycle { .. } => { - return Some(self.fetch_cold_cycle( - zalsa, - zalsa_local, - db, - id, - database_key_index, - memo_ingredient_index, - )); + return None; + } + ClaimResult::Cycle { .. 
} => { + return Some(self.fetch_cold_cycle( + zalsa, + zalsa_local, + db, + id, + database_key_index, + memo_ingredient_index, + )); + } } }; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 6ea17b13f..fc9d7ac88 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -141,21 +141,24 @@ where ) -> Option { let database_key_index = self.database_key_index(key_index); - let claim_guard = match self - .sync_table - .try_claim(zalsa, key_index, Reentrancy::Deny) - { - ClaimResult::Claimed(guard) => guard, - ClaimResult::Running(blocked_on) => { - blocked_on.block_on(zalsa); - return None; - } - ClaimResult::Cycle { .. } => { - return Some(self.maybe_changed_after_cold_cycle( - zalsa_local, - database_key_index, - cycle_heads, - )) + let claim_guard = loop { + match self + .sync_table + .try_claim(zalsa, zalsa_local, key_index, Reentrancy::Deny) + { + ClaimResult::Claimed(guard) => break guard, + ClaimResult::Running(blocked_on) => { + if blocked_on.block_on(zalsa) { + return None; + } + } + ClaimResult::Cycle { .. } => { + return Some(self.maybe_changed_after_cold_cycle( + zalsa_local, + database_key_index, + cycle_heads, + )) + } } }; // Load the current memo, if any. diff --git a/src/function/memo.rs b/src/function/memo.rs index 234829cb1..ac259b4b0 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -180,7 +180,10 @@ impl<'db, C: Configuration> Memo<'db, C> { } TryClaimHeadsResult::Running(running) => { all_cycles = false; - running.block_on(zalsa); + if !running.block_on(zalsa) { + // FIXME: Handle cancellation properly? + crate::Cancelled::PropagatedPanic.throw(); + } } } } diff --git a/src/function/sync.rs b/src/function/sync.rs index c9a74a307..edc881ffa 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -2,6 +2,7 @@ use rustc_hash::FxHashMap; use std::collections::hash_map::OccupiedEntry; use crate::key::DatabaseKeyIndex; +use crate::plumbing::ZalsaLocal; use crate::runtime::{ BlockOnTransferredOwner, BlockResult, BlockTransferredResult, Running, WaitResult, }; @@ -21,6 +22,8 @@ pub(crate) struct SyncTable { } pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> { + /// Successfully claimed the query. + Claimed(Guard), /// Can't claim the query because it is running on an other thread. Running(Running<'a>), /// Claiming the query results in a cycle. @@ -30,8 +33,6 @@ pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> { /// [`SyncTable::try_claim`] with [`Reentrant::Allow`]. inner: bool, }, - /// Successfully claimed the query. 
- Claimed(Guard), } pub(crate) struct SyncState { @@ -68,6 +69,7 @@ impl SyncTable { pub(crate) fn try_claim<'me>( &'me self, zalsa: &'me Zalsa, + zalsa_local: &'me ZalsaLocal, key_index: Id, reentrant: Reentrancy, ) -> ClaimResult<'me> { @@ -77,7 +79,12 @@ impl SyncTable { let id = match occupied_entry.get().id { SyncOwner::Thread(id) => id, SyncOwner::Transferred => { - return match self.try_claim_transferred(zalsa, occupied_entry, reentrant) { + return match self.try_claim_transferred( + zalsa, + zalsa_local, + occupied_entry, + reentrant, + ) { Ok(claimed) => claimed, Err(other_thread) => match other_thread.block(write) { BlockResult::Cycle => ClaimResult::Cycle { inner: false }, @@ -115,6 +122,7 @@ impl SyncTable { ClaimResult::Claimed(ClaimGuard { key_index, zalsa, + zalsa_local, sync_table: self, mode: ReleaseMode::Default, }) @@ -172,6 +180,7 @@ impl SyncTable { fn try_claim_transferred<'me>( &'me self, zalsa: &'me Zalsa, + zalsa_local: &'me ZalsaLocal, mut entry: OccupiedEntry, reentrant: Reentrancy, ) -> Result, Box>> { @@ -195,6 +204,7 @@ impl SyncTable { Ok(ClaimResult::Claimed(ClaimGuard { key_index, zalsa, + zalsa_local, sync_table: self, mode: ReleaseMode::SelfOnly, })) @@ -214,6 +224,7 @@ impl SyncTable { Ok(ClaimResult::Claimed(ClaimGuard { key_index, zalsa, + zalsa_local, sync_table: self, mode: ReleaseMode::Default, })) @@ -295,6 +306,7 @@ pub(crate) struct ClaimGuard<'me> { zalsa: &'me Zalsa, sync_table: &'me SyncTable, mode: ReleaseMode, + zalsa_local: &'me ZalsaLocal, } impl<'me> ClaimGuard<'me> { @@ -319,10 +331,21 @@ impl<'me> ClaimGuard<'me> { "Release claim on {:?} due to panic", self.database_key_index() ); - self.release(state, WaitResult::Panicked); } + #[cold] + #[inline(never)] + fn release_cancelled(&self) { + let mut syncs = self.sync_table.syncs.lock(); + let state = syncs.remove(&self.key_index).expect("key claimed twice?"); + tracing::debug!( + "Release claim on {:?} due to cancellation", + self.database_key_index() + ); + self.release(state, WaitResult::Cancelled); + } + #[inline(always)] fn release(&self, state: SyncState, wait_result: WaitResult) { let SyncState { @@ -446,7 +469,11 @@ impl<'me> ClaimGuard<'me> { impl Drop for ClaimGuard<'_> { fn drop(&mut self) { if thread::panicking() { - self.release_panicking(); + if self.zalsa_local.is_cancelled() { + self.release_cancelled(); + } else { + self.release_panicking(); + } return; } diff --git a/src/interned.rs b/src/interned.rs index 544a8d0ee..afd15f71e 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -850,6 +850,7 @@ pub struct StructEntry<'db, C> where C: Configuration, { + #[allow(dead_code)] value: &'db Value, key: DatabaseKeyIndex, } diff --git a/src/lib.rs b/src/lib.rs index f90fce338..d2324156f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,6 +63,7 @@ pub use self::runtime::Runtime; pub use self::storage::{Storage, StorageHandle}; pub use self::update::Update; pub use self::zalsa::IngredientIndex; +pub use self::zalsa_local::CancellationToken; pub use crate::attach::{attach, attach_allow_change, with_attached_database}; pub mod prelude { diff --git a/src/runtime.rs b/src/runtime.rs index 48caf53ec..a7feb2287 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -13,11 +13,11 @@ mod dependency_graph; #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct Runtime { - /// Set to true when the current revision has been canceled. + /// Set to true when the current revision has been cancelled. /// This is done when we an input is being changed. 
The flag /// is set back to false once the input has been changed. #[cfg_attr(feature = "persistence", serde(skip))] - revision_canceled: AtomicBool, + revision_cancelled: AtomicBool, /// Stores the "last change" revision for values of each duration. /// This vector is always of length at least 1 (for Durability 0) @@ -44,6 +44,7 @@ pub struct Runtime { pub(super) enum WaitResult { Completed, Panicked, + Cancelled, } #[derive(Debug)] @@ -121,7 +122,11 @@ struct BlockedOnInner<'me> { impl Running<'_> { /// Blocks on the other thread to complete the computation. - pub(crate) fn block_on(self, zalsa: &Zalsa) { + /// + /// Returns `true` if the computation was successful, and `false` if the other thread was cancelled. + #[must_use] + #[cold] + pub(crate) fn block_on(self, zalsa: &Zalsa) -> bool { let BlockedOnInner { dg, query_mutex_guard, @@ -151,7 +156,8 @@ impl Running<'_> { // by the other thread and responded to appropriately. Cancelled::PropagatedPanic.throw() } - WaitResult::Completed => {} + WaitResult::Cancelled => false, + WaitResult::Completed => true, } } } @@ -183,7 +189,7 @@ impl Default for Runtime { fn default() -> Self { Runtime { revisions: [Revision::start(); Durability::LEN], - revision_canceled: Default::default(), + revision_cancelled: Default::default(), dependency_graph: Default::default(), table: Default::default(), } @@ -194,7 +200,7 @@ impl std::fmt::Debug for Runtime { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fmt.debug_struct("Runtime") .field("revisions", &self.revisions) - .field("revision_canceled", &self.revision_canceled) + .field("revision_cancelled", &self.revision_cancelled) .field("dependency_graph", &self.dependency_graph) .finish() } @@ -227,16 +233,16 @@ impl Runtime { } pub(crate) fn load_cancellation_flag(&self) -> bool { - self.revision_canceled.load(Ordering::Acquire) + self.revision_cancelled.load(Ordering::Acquire) } pub(crate) fn set_cancellation_flag(&self) { crate::tracing::trace!("set_cancellation_flag"); - self.revision_canceled.store(true, Ordering::Release); + self.revision_cancelled.store(true, Ordering::Release); } pub(crate) fn reset_cancellation_flag(&mut self) { - *self.revision_canceled.get_mut() = false; + *self.revision_cancelled.get_mut() = false; } /// Returns the [`Table`] used to store the value of salsa structs diff --git a/src/storage.rs b/src/storage.rs index b6c532709..c2a7029ee 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -136,12 +136,15 @@ impl Storage { .record_unfilled_pages(self.handle.zalsa_impl.table()); let Self { handle, - zalsa_local: _, - } = &self; + zalsa_local, + } = &mut self; // Avoid rust's annoying destructure prevention rules for `Drop` types // SAFETY: We forget `Self` afterwards to discard the original copy, and the destructure // above makes sure we won't forget to take into account newly added fields. let handle = unsafe { std::ptr::read(handle) }; + // SAFETY: We forget `Self` afterwards to discard the original copy, and the destructure + // above makes sure we won't forget to take into account newly added fields. 
+ unsafe { std::ptr::drop_in_place(zalsa_local) }; std::mem::forget::(self); handle } diff --git a/src/zalsa.rs b/src/zalsa.rs index f2f520b49..e9fb6d8de 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -296,8 +296,11 @@ impl Zalsa { #[inline] pub(crate) fn unwind_if_revision_cancelled(&self, zalsa_local: &ZalsaLocal) { self.event(&|| crate::Event::new(crate::EventKind::WillCheckCancellation)); + if zalsa_local.is_cancelled() { + zalsa_local.unwind_cancelled(); + } if self.runtime().load_cancellation_flag() { - zalsa_local.unwind_cancelled(self.current_revision()); + zalsa_local.unwind_pending_write(); } } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 8f0239e56..0385fee80 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -3,6 +3,8 @@ use std::fmt; use std::fmt::Formatter; use std::panic::UnwindSafe; use std::ptr::{self, NonNull}; +use std::sync::atomic::Ordering; +use std::sync::Arc; use rustc_hash::FxHashMap; use thin_vec::ThinVec; @@ -39,6 +41,28 @@ pub struct ZalsaLocal { /// Stores the most recent page for a given ingredient. /// This is thread-local to avoid contention. most_recent_pages: UnsafeCell>, + + cancelled: CancellationToken, +} + +/// A cancellation token that can be used to cancel a query computation for a specific local `Database`. +#[derive(Default, Clone, Debug)] +pub struct CancellationToken(Arc); + +impl CancellationToken { + /// Inform the database to cancel the current query computation. + pub fn cancel(&self) { + self.0.store(true, Ordering::Relaxed); + } + + /// Check if the query computation has been requested to be cancelled. + pub fn is_cancelled(&self) -> bool { + self.0.load(Ordering::Relaxed) + } + + pub(crate) fn uncancel(&self) { + self.0.store(false, Ordering::Relaxed); + } } impl ZalsaLocal { @@ -46,6 +70,7 @@ impl ZalsaLocal { ZalsaLocal { query_stack: RefCell::new(QueryStack::default()), most_recent_pages: UnsafeCell::new(FxHashMap::default()), + cancelled: CancellationToken::default(), } } @@ -401,12 +426,30 @@ impl ZalsaLocal { } } + #[inline] + pub(crate) fn cancellation_token(&self) -> CancellationToken { + self.cancelled.clone() + } + + #[inline] + pub(crate) fn uncancel(&self) { + self.cancelled.uncancel(); + } + + #[inline] + pub fn is_cancelled(&self) -> bool { + self.cancelled.0.load(Ordering::Relaxed) + } + #[cold] - pub(crate) fn unwind_cancelled(&self, current_revision: Revision) { - // Why is this reporting an untracked read? We do not store the query revisions on unwind do we? - self.report_untracked_read(current_revision); + pub(crate) fn unwind_pending_write(&self) { Cancelled::PendingWrite.throw(); } + + #[cold] + pub(crate) fn unwind_cancelled(&self) { + Cancelled::Cancelled.throw(); + } } // Okay to implement as `ZalsaLocal`` is !Sync diff --git a/tests/cancellation_token.rs b/tests/cancellation_token.rs new file mode 100644 index 000000000..9aec792b6 --- /dev/null +++ b/tests/cancellation_token.rs @@ -0,0 +1,67 @@ +#![cfg(feature = "inventory")] +//! Test that `DeriveWithDb` is correctly derived. 
+ +mod common; + +use std::{sync::Barrier, thread}; + +use expect_test::expect; +use salsa::{Cancelled, Database}; + +use crate::common::LogDatabase; + +#[salsa::input(debug)] +struct MyInput { + field: u32, +} + +#[salsa::tracked] +fn a(db: &dyn Database, input: MyInput) -> u32 { + BARRIER.wait(); + BARRIER2.wait(); + b(db, input) +} +#[salsa::tracked] +fn b(db: &dyn Database, input: MyInput) -> u32 { + input.field(db) +} + +static BARRIER: Barrier = Barrier::new(2); +static BARRIER2: Barrier = Barrier::new(2); + +#[test] +fn cancellation_token() { + let db = common::EventLoggerDatabase::default(); + let token = db.cancellation_token(); + let input = MyInput::new(&db, 22); + let res = Cancelled::catch(|| { + thread::scope(|s| { + s.spawn(|| { + BARRIER.wait(); + token.cancel(); + BARRIER2.wait(); + }); + a(&db, input) + }) + }); + assert!(matches!(res, Err(Cancelled::Cancelled)), "{res:?}"); + drop(res); + db.assert_logs(expect![[r#" + [ + "WillCheckCancellation", + "WillExecute { database_key: a(Id(0)) }", + "WillCheckCancellation", + ]"#]]); + thread::spawn(|| { + BARRIER.wait(); + BARRIER2.wait(); + }); + a(&db, input); + db.assert_logs(expect![[r#" + [ + "WillCheckCancellation", + "WillExecute { database_key: a(Id(0)) }", + "WillCheckCancellation", + "WillExecute { database_key: b(Id(0)) }", + ]"#]]); +} diff --git a/tests/interned-revisions.rs b/tests/interned-revisions.rs index bef1db61c..41f762895 100644 --- a/tests/interned-revisions.rs +++ b/tests/interned-revisions.rs @@ -156,7 +156,7 @@ fn test_immortal() { // Modify the input to bump the revision and intern a new value. // // No values should ever be reused with `durability = usize::MAX`. - for i in 1..100 { + for i in 1..if cfg!(miri) { 50 } else { 1000 } { input.set_field1(&mut db).to(i); let result = function(&db, input); assert_eq!(result.field1(&db).0, i); diff --git a/tests/parallel/cancellation_token_recomputes.rs b/tests/parallel/cancellation_token_recomputes.rs new file mode 100644 index 000000000..b7963afc5 --- /dev/null +++ b/tests/parallel/cancellation_token_recomputes.rs @@ -0,0 +1,43 @@ +// Shuttle doesn't like panics inside of its runtime. +#![cfg(not(feature = "shuttle"))] + +//! Test for cancellation when another query is blocked on the cancelled thread. 
+use salsa::{Cancelled, Database}; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[salsa::tracked] +fn query_a(db: &dyn KnobsDatabase) -> u32 { + query_b(db) +} + +#[salsa::tracked] +fn query_b(db: &dyn KnobsDatabase) -> u32 { + db.signal(1); + db.wait_for(3); + query_c(db) +} + +#[salsa::tracked] +fn query_c(_db: &dyn KnobsDatabase) -> u32 { + 1 +} +#[test] +fn execute() { + let db = Knobs::default(); + let db2 = db.clone(); + let db_signaler = db.clone(); + let token = db.cancellation_token(); + + let t1 = std::thread::spawn(move || query_a(&db)); + db_signaler.wait_for(1); + db2.signal_on_will_block(2); + let t2 = std::thread::spawn(move || query_a(&db2)); + db_signaler.wait_for(2); + token.cancel(); + db_signaler.signal(3); + let (r1, r2) = (t1.join(), t2.join()); + let r1 = *r1.unwrap_err().downcast::().unwrap(); + assert!(matches!(r1, Cancelled::Cancelled), "{r1:?}"); + assert_eq!(r2.unwrap(), 1); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 1062d4899..41a6f7453 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -3,6 +3,7 @@ mod setup; mod signal; +mod cancellation_token_recomputes; mod cycle_a_t1_b_t2; mod cycle_a_t1_b_t2_fallback; mod cycle_ab_peeping_c; From 796d08bb6e2d6f4068af5a8c99adb7aa4154fedc Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Thu, 30 Oct 2025 08:38:10 +0100 Subject: [PATCH 2/7] Address reviews --- src/cancelled.rs | 4 +- src/database.rs | 2 +- src/function/fetch.rs | 54 +++++++++---------- src/function/maybe_changed_after.rs | 12 ++--- src/function/memo.rs | 2 +- src/runtime.rs | 7 ++- src/zalsa_local.rs | 2 +- tests/cancellation_token.rs | 2 +- .../parallel/cancellation_token_recomputes.rs | 2 +- 9 files changed, 43 insertions(+), 44 deletions(-) diff --git a/src/cancelled.rs b/src/cancelled.rs index 1fa0edc59..5fe69e7d1 100644 --- a/src/cancelled.rs +++ b/src/cancelled.rs @@ -11,7 +11,7 @@ use std::panic::{self, UnwindSafe}; #[non_exhaustive] pub enum Cancelled { /// The query was operating but the local database execution has been cancelled. - Cancelled, + Local, /// The query was operating on revision R, but there is a pending write to move to revision R+1. PendingWrite, @@ -46,7 +46,7 @@ impl Cancelled { impl std::fmt::Display for Cancelled { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let why = match self { - Cancelled::Cancelled => "canellation request", + Cancelled::Local => "local canellation request", Cancelled::PendingWrite => "pending write", Cancelled::PropagatedPanic => "propagated panic", }; diff --git a/src/database.rs b/src/database.rs index 9cb70f917..0831fd5bf 100644 --- a/src/database.rs +++ b/src/database.rs @@ -68,7 +68,7 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase { let _ = self.zalsa_mut(); } - /// Retrives a [`CancellationToken`] for the current database. + /// Retrieves a [`CancellationToken`] for the current database handle. fn cancellation_token(&self) -> CancellationToken { self.zalsa_local().cancellation_token() } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 313bd43ee..7abc6236b 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -105,39 +105,37 @@ where ) -> Option<&'db Memo<'db, C>> { let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. 
- let claim_guard = loop { - match self - .sync_table - .try_claim(zalsa, zalsa_local, id, Reentrancy::Allow) - { - ClaimResult::Claimed(guard) => break guard, - ClaimResult::Running(blocked_on) => { - if !blocked_on.block_on(zalsa) { - continue; - } + let claim_guard = match self + .sync_table + .try_claim(zalsa, zalsa_local, id, Reentrancy::Allow) + { + ClaimResult::Claimed(guard) => guard, + ClaimResult::Running(blocked_on) => { + if !blocked_on.block_on(zalsa) { + return None; + } - if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { - let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { + let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); - if let Some(memo) = memo { - if memo.value.is_some() { - memo.block_on_heads(zalsa); - } + if let Some(memo) = memo { + if memo.value.is_some() { + memo.block_on_heads(zalsa); } } - - return None; - } - ClaimResult::Cycle { .. } => { - return Some(self.fetch_cold_cycle( - zalsa, - zalsa_local, - db, - id, - database_key_index, - memo_ingredient_index, - )); } + + return None; + } + ClaimResult::Cycle { .. } => { + return Some(self.fetch_cold_cycle( + zalsa, + zalsa_local, + db, + id, + database_key_index, + memo_ingredient_index, + )); } }; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index fc9d7ac88..53ad0dafb 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -141,16 +141,15 @@ where ) -> Option { let database_key_index = self.database_key_index(key_index); - let claim_guard = loop { + let claim_guard = match self .sync_table .try_claim(zalsa, zalsa_local, key_index, Reentrancy::Deny) { - ClaimResult::Claimed(guard) => break guard, + ClaimResult::Claimed(guard) => guard, ClaimResult::Running(blocked_on) => { - if blocked_on.block_on(zalsa) { - return None; - } + _ = blocked_on.block_on(zalsa); + return None; } ClaimResult::Cycle { .. } => { return Some(self.maybe_changed_after_cold_cycle( @@ -159,8 +158,7 @@ where cycle_heads, )) } - } - }; + }; // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index) else { diff --git a/src/function/memo.rs b/src/function/memo.rs index ac259b4b0..c54ff0b93 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -504,7 +504,7 @@ mod _memory_usage { use std::any::TypeId; use std::num::NonZeroUsize; - // Memo's are stored a lot, make sure their size is doesn't randomly increase. + // Memo's are stored a lot, make sure their size doesn't randomly increase. const _: [(); std::mem::size_of::>()] = [(); std::mem::size_of::<[usize; 6]>()]; diff --git a/src/runtime.rs b/src/runtime.rs index a7feb2287..5b36bf205 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -123,9 +123,12 @@ struct BlockedOnInner<'me> { impl Running<'_> { /// Blocks on the other thread to complete the computation. /// - /// Returns `true` if the computation was successful, and `false` if the other thread was cancelled. + /// Returns `true` if the computation was successful, and `false` if the other thread was locally cancelled. + /// + /// # Panics + /// + /// If the other thread panics, this function will panic as well. 
#[must_use] - #[cold] pub(crate) fn block_on(self, zalsa: &Zalsa) -> bool { let BlockedOnInner { dg, diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 0385fee80..cb1a1cdd2 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -448,7 +448,7 @@ impl ZalsaLocal { #[cold] pub(crate) fn unwind_cancelled(&self) { - Cancelled::Cancelled.throw(); + Cancelled::Local.throw(); } } diff --git a/tests/cancellation_token.rs b/tests/cancellation_token.rs index 9aec792b6..f6a14930a 100644 --- a/tests/cancellation_token.rs +++ b/tests/cancellation_token.rs @@ -44,7 +44,7 @@ fn cancellation_token() { a(&db, input) }) }); - assert!(matches!(res, Err(Cancelled::Cancelled)), "{res:?}"); + assert!(matches!(res, Err(Cancelled::Local)), "{res:?}"); drop(res); db.assert_logs(expect![[r#" [ diff --git a/tests/parallel/cancellation_token_recomputes.rs b/tests/parallel/cancellation_token_recomputes.rs index b7963afc5..0bbb67ef0 100644 --- a/tests/parallel/cancellation_token_recomputes.rs +++ b/tests/parallel/cancellation_token_recomputes.rs @@ -38,6 +38,6 @@ fn execute() { db_signaler.signal(3); let (r1, r2) = (t1.join(), t2.join()); let r1 = *r1.unwrap_err().downcast::().unwrap(); - assert!(matches!(r1, Cancelled::Cancelled), "{r1:?}"); + assert!(matches!(r1, Cancelled::Local), "{r1:?}"); assert_eq!(r2.unwrap(), 1); } From 2141ac9d65efd96998282af39e54cb3b368968db Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 21 Nov 2025 10:03:08 +0100 Subject: [PATCH 3/7] Disable local cancellation while within a cycle computation --- src/function/execute.rs | 19 ++++++++++++------- src/function/memo.rs | 6 +++++- src/function/sync.rs | 4 ++-- src/zalsa.rs | 2 +- src/zalsa_local.rs | 33 ++++++++++++++++++++++++--------- 5 files changed, 44 insertions(+), 20 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index a4dbe4986..7edba8253 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -109,13 +109,18 @@ where (new_value, completed_query) } - CycleRecoveryStrategy::Fixpoint => self.execute_maybe_iterate( - db, - opt_old_memo, - &mut claim_guard, - zalsa_local, - memo_ingredient_index, - ), + CycleRecoveryStrategy::Fixpoint => { + let was_disabled = zalsa_local.set_cancellation_disabled(true); + let res = self.execute_maybe_iterate( + db, + opt_old_memo, + &mut claim_guard, + zalsa_local, + memo_ingredient_index, + ); + zalsa_local.set_cancellation_disabled(was_disabled); + res + } }; if let Some(old_memo) = opt_old_memo { diff --git a/src/function/memo.rs b/src/function/memo.rs index c54ff0b93..b0f1af37d 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -181,7 +181,11 @@ impl<'db, C: Configuration> Memo<'db, C> { TryClaimHeadsResult::Running(running) => { all_cycles = false; if !running.block_on(zalsa) { - // FIXME: Handle cancellation properly? + // We cannot handle local cancellations in fixpoints + // so we treat it as a general cancellation / panic. + // + // We shouldn't hit this though as we disable local cancellation + // in cycles. crate::Cancelled::PropagatedPanic.throw(); } } diff --git a/src/function/sync.rs b/src/function/sync.rs index edc881ffa..c97815ecb 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -288,7 +288,7 @@ pub enum SyncOwner { /// The query's lock ownership has been transferred to another query. /// E.g. if `a` transfers its ownership to `b`, then only the thread in the critical path - /// to complete b` can claim `a` (in most instances, only the thread owning `b` can claim `a`). 
+ /// to complete `b` can claim `a` (in most instances, only the thread owning `b` can claim `a`). /// /// The thread owning `a` is stored in the `DependencyGraph`. /// @@ -469,7 +469,7 @@ impl<'me> ClaimGuard<'me> { impl Drop for ClaimGuard<'_> { fn drop(&mut self) { if thread::panicking() { - if self.zalsa_local.is_cancelled() { + if self.zalsa_local.should_trigger_local_cancellation() { self.release_cancelled(); } else { self.release_panicking(); diff --git a/src/zalsa.rs b/src/zalsa.rs index e9fb6d8de..d550b4c1f 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -296,7 +296,7 @@ impl Zalsa { #[inline] pub(crate) fn unwind_if_revision_cancelled(&self, zalsa_local: &ZalsaLocal) { self.event(&|| crate::Event::new(crate::EventKind::WillCheckCancellation)); - if zalsa_local.is_cancelled() { + if zalsa_local.should_trigger_local_cancellation() { zalsa_local.unwind_cancelled(); } if self.runtime().load_cancellation_flag() { diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index cb1a1cdd2..b05c4b4ac 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -3,7 +3,7 @@ use std::fmt; use std::fmt::Formatter; use std::panic::UnwindSafe; use std::ptr::{self, NonNull}; -use std::sync::atomic::Ordering; +use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use rustc_hash::FxHashMap; @@ -47,21 +47,32 @@ pub struct ZalsaLocal { /// A cancellation token that can be used to cancel a query computation for a specific local `Database`. #[derive(Default, Clone, Debug)] -pub struct CancellationToken(Arc); +pub struct CancellationToken(Arc); impl CancellationToken { + const CANCELLED_MASK: u8 = 0b01; + const DISABLED_MASK: u8 = 0b10; + /// Inform the database to cancel the current query computation. pub fn cancel(&self) { - self.0.store(true, Ordering::Relaxed); + self.0.fetch_or(Self::CANCELLED_MASK, Ordering::Relaxed); } /// Check if the query computation has been requested to be cancelled. 
pub fn is_cancelled(&self) -> bool { - self.0.load(Ordering::Relaxed) + self.0.load(Ordering::Relaxed) & Self::CANCELLED_MASK != 0 } - pub(crate) fn uncancel(&self) { - self.0.store(false, Ordering::Relaxed); + fn set_cancellation_disabled(&self, disabled: bool) -> bool { + self.0.fetch_or((disabled as u8) << 1, Ordering::Relaxed) & Self::DISABLED_MASK != 0 + } + + fn should_trigger_local_cancellation(&self) -> bool { + self.0.load(Ordering::Relaxed) == Self::CANCELLED_MASK + } + + fn reset(&self) { + self.0.store(0, Ordering::Relaxed); } } @@ -433,12 +444,12 @@ impl ZalsaLocal { #[inline] pub(crate) fn uncancel(&self) { - self.cancelled.uncancel(); + self.cancelled.reset(); } #[inline] - pub fn is_cancelled(&self) -> bool { - self.cancelled.0.load(Ordering::Relaxed) + pub fn should_trigger_local_cancellation(&self) -> bool { + self.cancelled.should_trigger_local_cancellation() } #[cold] @@ -450,6 +461,10 @@ impl ZalsaLocal { pub(crate) fn unwind_cancelled(&self) { Cancelled::Local.throw(); } + + pub(crate) fn set_cancellation_disabled(&self, was_disabled: bool) -> bool { + self.cancelled.set_cancellation_disabled(was_disabled) + } } // Okay to implement as `ZalsaLocal`` is !Sync From c07e29777d336f3b3d7cee9f313be281869959c6 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 26 Dec 2025 13:09:03 +0100 Subject: [PATCH 4/7] Add cycle test --- .github/workflows/test.yml | 2 +- src/cancelled.rs | 2 +- .../cancellation_token_cycle_nested.rs | 146 ++++++++++++++++++ .../cancellation_token_multi_blocked.rs | 86 +++++++++++ tests/parallel/main.rs | 2 + 5 files changed, 236 insertions(+), 2 deletions(-) create mode 100644 tests/parallel/cancellation_token_cycle_nested.rs create mode 100644 tests/parallel/cancellation_token_multi_blocked.rs diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ee9aa06b8..1defab25b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -91,7 +91,7 @@ jobs: - name: Test with Miri run: cargo miri nextest run --no-fail-fast --tests env: - MIRIFLAGS: -Zmiri-disable-isolation -Zmiri-retag-fields + MIRIFLAGS: -Zmiri-disable-isolation - name: Run examples with Miri run: cargo miri run --example calc diff --git a/src/cancelled.rs b/src/cancelled.rs index 5fe69e7d1..e690eac35 100644 --- a/src/cancelled.rs +++ b/src/cancelled.rs @@ -46,7 +46,7 @@ impl Cancelled { impl std::fmt::Display for Cancelled { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let why = match self { - Cancelled::Local => "local canellation request", + Cancelled::Local => "local cancellation request", Cancelled::PendingWrite => "pending write", Cancelled::PropagatedPanic => "propagated panic", }; diff --git a/tests/parallel/cancellation_token_cycle_nested.rs b/tests/parallel/cancellation_token_cycle_nested.rs new file mode 100644 index 000000000..65f8ad7a6 --- /dev/null +++ b/tests/parallel/cancellation_token_cycle_nested.rs @@ -0,0 +1,146 @@ +// Shuttle doesn't like panics inside of its runtime. +#![cfg(not(feature = "shuttle"))] + +//! Test for cancellation with deeply nested cycles across multiple threads. +//! +//! These tests verify that local cancellation is disabled during cycle iteration, +//! allowing multi-threaded cycles to complete successfully before cancellation +//! can take effect. 
+use salsa::Database; + +use crate::setup::{Knobs, KnobsDatabase}; +use crate::sync::thread; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(3); + +#[salsa::tracked(cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + query_b(db) +} + +#[salsa::tracked(cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + let c_value = query_c(db); + CycleValue(c_value.0 + 1).min(MAX) +} + +#[salsa::tracked(cycle_initial=initial)] +fn query_c(db: &dyn KnobsDatabase) -> CycleValue { + let d_value = query_d(db); + let e_value = query_e(db); + let b_value = query_b(db); + let a_value = query_a(db); + CycleValue(d_value.0.max(e_value.0).max(b_value.0).max(a_value.0)) +} + +#[salsa::tracked(cycle_initial=initial)] +fn query_d(db: &dyn KnobsDatabase) -> CycleValue { + query_c(db) +} + +#[salsa::tracked(cycle_initial=initial)] +fn query_e(db: &dyn KnobsDatabase) -> CycleValue { + query_c(db) +} + +#[salsa::tracked] +fn query_f(db: &dyn KnobsDatabase) -> CycleValue { + let c = query_c(db); + // this should trigger cancellation again + query_h(db); + c +} + +#[salsa::tracked] +fn query_h(db: &dyn KnobsDatabase) { + _ = db; +} + +fn initial(db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { + db.signal(1); + db.wait_for(6); + MIN +} + +/// Test that a multi-threaded cycle completes successfully even when +/// cancellation is requested during the cycle. +/// +/// This test is similar to cycle_nested_deep but adds cancellation during +/// the cycle to verify that cancellation is properly deferred. +#[test] +fn multi_threaded_cycle_completes_despite_cancellation() { + let db = Knobs::default(); + let db_t1 = db.clone(); + let db_t2 = db.clone(); + let db_t3 = db.clone(); + let db_t4 = db.clone(); + let db_t5 = db.clone(); + let db_signaler = db; + + let token_t1 = db_t1.cancellation_token(); + let token_t2 = db_t2.cancellation_token(); + let token_t3 = db_t3.cancellation_token(); + let token_t5 = db_t5.cancellation_token(); + + // Thread 1: Runs the main cycle, will have cancellation requested during it + let t1 = thread::spawn(move || query_a(&db_t1)); + + // Wait for t1 to start the cycle + db_signaler.wait_for(1); + + // Spawn t2 and wait for it to block on the cycle + db_signaler.signal_on_will_block(2); + let t2 = thread::spawn(move || query_b(&db_t2)); + db_signaler.wait_for(2); + + // Spawn t3 and wait for it to block on the cycle + db_signaler.signal_on_will_block(3); + let t3 = thread::spawn(move || query_d(&db_t3)); + db_signaler.wait_for(3); + + // Spawn t4 - doesn't get cancelled + db_signaler.signal_on_will_block(4); + let t4 = thread::spawn(move || query_e(&db_t4)); + db_signaler.wait_for(4); + + // Spawn t4 - doesn't get cancelled + db_signaler.signal_on_will_block(5); + let t5 = thread::spawn(move || query_f(&db_t5)); + db_signaler.wait_for(5); + + // Request cancellation while t2 and t3 are blocked on the cycle + // This should be deferred until after the cycle completes + token_t1.cancel(); + token_t2.cancel(); + token_t3.cancel(); + token_t5.cancel(); + + // Let t1 continue - the cycle should still complete because + // cancellation is disabled during fixpoint iteration + db_signaler.signal(6); + + // All threads should complete successfully + let r_t1 = t1.join().unwrap(); + let r_t2 = t2.join().unwrap(); + let r_t3 = t3.join().unwrap(); + let r_t4 = t4.join().unwrap(); + + let r_t5 = t5.join().unwrap_err(); 
+ + // All should get MAX because cycles defer cancellation + assert_eq!(r_t1, MAX, "t1 should get MAX"); + assert_eq!(r_t2, MAX, "t2 should get MAX"); + assert_eq!(r_t3, MAX, "t3 should get MAX"); + assert_eq!(r_t4, MAX, "t4 should get MAX"); + assert!( + matches!( + *r_t5.downcast::().unwrap(), + salsa::Cancelled::Local + ), + "t5 should be cancelled as its blocked on the cycle, not participating in it and calling an uncomputed query after" + ); +} diff --git a/tests/parallel/cancellation_token_multi_blocked.rs b/tests/parallel/cancellation_token_multi_blocked.rs new file mode 100644 index 000000000..3e3caf231 --- /dev/null +++ b/tests/parallel/cancellation_token_multi_blocked.rs @@ -0,0 +1,86 @@ +// Shuttle doesn't like panics inside of its runtime. +#![cfg(not(feature = "shuttle"))] + +//! Test for cancellation when multiple queries are blocked on the cancelled thread. +//! +//! This test verifies that: +//! 1. When a thread is cancelled, blocked threads recompute rather than propagate cancellation +//! 2. The final result is correctly computed by the remaining threads +use salsa::{Cancelled, Database}; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[salsa::tracked] +fn query_a(db: &dyn KnobsDatabase) -> u32 { + query_b(db) +} + +#[salsa::tracked] +fn query_b(db: &dyn KnobsDatabase) -> u32 { + // Signal that t1 has started computing query_b + db.signal(1); + // Wait for t2 and t3 to block on us + db.wait_for(3); + // Wait for cancellation to happen + db.wait_for(4); + query_c(db) +} + +#[salsa::tracked] +fn query_c(_db: &dyn KnobsDatabase) -> u32 { + 42 +} + +/// Test that when a thread is cancelled, other blocked threads successfully +/// recompute the query and get the correct result. +#[test] +fn multiple_threads_blocked_on_cancelled() { + let db = Knobs::default(); + let db2 = db.clone(); + let db3 = db.clone(); + let db_signaler = db.clone(); + let token = db.cancellation_token(); + + // Thread 1: Starts computing query_a -> query_b, will be cancelled + let t1 = std::thread::spawn(move || query_a(&db)); + + // Wait for t1 to start query_b + db_signaler.wait_for(1); + + // Thread 2: Will block on query_a (which is blocked on query_b) + db2.signal_on_will_block(2); + let t2 = std::thread::spawn(move || query_a(&db2)); + + // Wait for t2 to block + db_signaler.wait_for(2); + + // Thread 3: Also blocks on query_a + db3.signal_on_will_block(3); + let t3 = std::thread::spawn(move || query_a(&db3)); + + // Wait for t3 to block + db_signaler.wait_for(3); + + // Now cancel t1 + token.cancel(); + + // Let t1 continue and get cancelled + db_signaler.signal(4); + + // Collect results + let r1 = t1.join(); + let r2 = t2.join(); + let r3 = t3.join(); + + // t1 should have been cancelled + let r1_cancelled = r1.unwrap_err().downcast::().map(|c| *c); + assert!( + matches!(r1_cancelled, Ok(Cancelled::Local)), + "t1 should be locally cancelled, got: {:?}", + r1_cancelled + ); + + // t2 and t3 should both succeed with the correct value + assert_eq!(r2.unwrap(), 42, "t2 should compute the correct result"); + assert_eq!(r3.unwrap(), 42, "t3 should compute the correct result"); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 41a6f7453..399eaa7da 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -3,6 +3,8 @@ mod setup; mod signal; +mod cancellation_token_cycle_nested; +mod cancellation_token_multi_blocked; mod cancellation_token_recomputes; mod cycle_a_t1_b_t2; mod cycle_a_t1_b_t2_fallback; From 48ad3314e23f66b6cf2c9f171b92330ac76dda86 Mon Sep 17 00:00:00 2001 
From: Lukas Wirth Date: Fri, 26 Dec 2025 13:37:55 +0100 Subject: [PATCH 5/7] Fix `set_cancellation_disabled` implementation --- src/zalsa_local.rs | 9 ++++++++- tests/interned-revisions.rs | 2 +- tests/parallel/cancellation_token_cycle_nested.rs | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index b05c4b4ac..d60582eab 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -63,8 +63,14 @@ impl CancellationToken { self.0.load(Ordering::Relaxed) & Self::CANCELLED_MASK != 0 } + #[inline] fn set_cancellation_disabled(&self, disabled: bool) -> bool { - self.0.fetch_or((disabled as u8) << 1, Ordering::Relaxed) & Self::DISABLED_MASK != 0 + let previous_disabled_bit = if disabled { + self.0.fetch_or(Self::DISABLED_MASK, Ordering::Relaxed) + } else { + self.0.fetch_and(!Self::DISABLED_MASK, Ordering::Relaxed) + }; + previous_disabled_bit & Self::DISABLED_MASK != 0 } fn should_trigger_local_cancellation(&self) -> bool { @@ -462,6 +468,7 @@ impl ZalsaLocal { Cancelled::Local.throw(); } + #[inline] pub(crate) fn set_cancellation_disabled(&self, was_disabled: bool) -> bool { self.cancelled.set_cancellation_disabled(was_disabled) } diff --git a/tests/interned-revisions.rs b/tests/interned-revisions.rs index 41f762895..02d7d8112 100644 --- a/tests/interned-revisions.rs +++ b/tests/interned-revisions.rs @@ -155,7 +155,7 @@ fn test_immortal() { // Modify the input to bump the revision and intern a new value. // - // No values should ever be reused with `durability = usize::MAX`. + // No values should ever be reused with `revisions = usize::MAX`. for i in 1..if cfg!(miri) { 50 } else { 1000 } { input.set_field1(&mut db).to(i); let result = function(&db, input); diff --git a/tests/parallel/cancellation_token_cycle_nested.rs b/tests/parallel/cancellation_token_cycle_nested.rs index 65f8ad7a6..b1e8fbb62 100644 --- a/tests/parallel/cancellation_token_cycle_nested.rs +++ b/tests/parallel/cancellation_token_cycle_nested.rs @@ -107,7 +107,7 @@ fn multi_threaded_cycle_completes_despite_cancellation() { let t4 = thread::spawn(move || query_e(&db_t4)); db_signaler.wait_for(4); - // Spawn t4 - doesn't get cancelled + // Spawn t5 - doesn't get cancelled db_signaler.signal_on_will_block(5); let t5 = thread::spawn(move || query_f(&db_t5)); db_signaler.wait_for(5); From 08d898197fab08a4064a05941e76084156ea3865 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Wed, 31 Dec 2025 16:34:17 +0100 Subject: [PATCH 6/7] Expose `zalsa_local` from `ClaimGuard` --- src/function/execute.rs | 28 ++++++++++++++++-------- src/function/fetch.rs | 2 +- src/function/maybe_changed_after.rs | 2 +- src/function/sync.rs | 34 ++++++++++++----------------- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index 7edba8253..11d8f0fdd 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -38,7 +38,6 @@ where &'db self, db: &'db C::DbView, mut claim_guard: ClaimGuard<'db>, - zalsa_local: &'db ZalsaLocal, opt_old_memo: Option<&Memo<'db, C>>, ) -> Option<&'db Memo<'db, C>> { let database_key_index = claim_guard.database_key_index(); @@ -60,7 +59,9 @@ where let (new_value, active_query) = Self::execute_query( db, zalsa, - zalsa_local.push_query(database_key_index, IterationCount::initial()), + claim_guard + .zalsa_local() + .push_query(database_key_index, IterationCount::initial()), opt_old_memo, ); (new_value, active_query.pop()) @@ -69,7 +70,9 @@ where let (mut new_value, active_query) = Self::execute_query( 
db, zalsa, - zalsa_local.push_query(database_key_index, IterationCount::initial()), + claim_guard + .zalsa_local() + .push_query(database_key_index, IterationCount::initial()), opt_old_memo, ); @@ -97,8 +100,9 @@ where // Cycle participants that don't have a fallback will be discarded in // `validate_provisional()`. let cycle_heads = std::mem::take(cycle_heads); - let active_query = - zalsa_local.push_query(database_key_index, IterationCount::initial()); + let active_query = claim_guard + .zalsa_local() + .push_query(database_key_index, IterationCount::initial()); new_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); completed_query = active_query.pop(); // We need to set `cycle_heads` and `verified_final` because it needs to propagate to the callers. @@ -110,12 +114,12 @@ where (new_value, completed_query) } CycleRecoveryStrategy::Fixpoint => { + let zalsa_local = claim_guard.zalsa_local(); let was_disabled = zalsa_local.set_cancellation_disabled(true); let res = self.execute_maybe_iterate( db, opt_old_memo, &mut claim_guard, - zalsa_local, memo_ingredient_index, ); zalsa_local.set_cancellation_disabled(was_disabled); @@ -163,7 +167,6 @@ where db: &'db C::DbView, opt_old_memo: Option<&Memo<'db, C>>, claim_guard: &mut ClaimGuard<'db>, - zalsa_local: &'db ZalsaLocal, memo_ingredient_index: MemoIngredientIndex, ) -> (C::Output<'db>, CompletedQuery) { claim_guard.set_release_mode(ReleaseMode::Default); @@ -210,7 +213,9 @@ where PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index); let (new_value, completed_query) = loop { - let active_query = zalsa_local.push_query(database_key_index, iteration_count); + let active_query = claim_guard + .zalsa_local() + .push_query(database_key_index, iteration_count); // Tracked struct ids that existed in the previous revision // but weren't recreated in the last iteration. It's important that we seed the next @@ -253,7 +258,12 @@ where iteration_count, ); - let outer_cycle = outer_cycle(zalsa, zalsa_local, &cycle_heads, database_key_index); + let outer_cycle = outer_cycle( + zalsa, + claim_guard.zalsa_local(), + &cycle_heads, + database_key_index, + ); // Did the new result we got depend on our own provisional value, in a cycle? // If not, return because this query is not a cycle head. diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 7abc6236b..48eee089f 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -181,7 +181,7 @@ where } } - self.execute(db, claim_guard, zalsa_local, opt_old_memo) + self.execute(db, claim_guard, opt_old_memo) } #[cold] diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 53ad0dafb..f25d955e1 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -229,7 +229,7 @@ where // `in_cycle` tracks if the enclosing query is in a cycle. `deep_verify.cycle_heads` tracks // if **this query** encountered a cycle (which means there's some provisional value somewhere floating around). if old_memo.value.is_some() && !cycle_heads.has_any() { - let memo = self.execute(db, claim_guard, zalsa_local, Some(old_memo))?; + let memo = self.execute(db, claim_guard, Some(old_memo))?; let changed_at = memo.revisions.changed_at; // Always assume that a provisional value has changed. 
diff --git a/src/function/sync.rs b/src/function/sync.rs index c97815ecb..f01932a37 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -314,6 +314,10 @@ impl<'me> ClaimGuard<'me> { self.zalsa } + pub(crate) fn zalsa_local(&self) -> &'me ZalsaLocal { + self.zalsa_local + } + pub(crate) const fn database_key_index(&self) -> DatabaseKeyIndex { DatabaseKeyIndex::new(self.sync_table.ingredient, self.key_index) } @@ -327,23 +331,17 @@ impl<'me> ClaimGuard<'me> { fn release_panicking(&self) { let mut syncs = self.sync_table.syncs.lock(); let state = syncs.remove(&self.key_index).expect("key claimed twice?"); + let result = if self.zalsa_local.should_trigger_local_cancellation() { + WaitResult::Cancelled + } else { + WaitResult::Panicked + }; tracing::debug!( - "Release claim on {:?} due to panic", - self.database_key_index() - ); - self.release(state, WaitResult::Panicked); - } - - #[cold] - #[inline(never)] - fn release_cancelled(&self) { - let mut syncs = self.sync_table.syncs.lock(); - let state = syncs.remove(&self.key_index).expect("key claimed twice?"); - tracing::debug!( - "Release claim on {:?} due to cancellation", - self.database_key_index() + "Release claim on {:?} due to {:?}", + self.database_key_index(), + result ); - self.release(state, WaitResult::Cancelled); + self.release(state, result); } #[inline(always)] @@ -469,11 +467,7 @@ impl<'me> ClaimGuard<'me> { impl Drop for ClaimGuard<'_> { fn drop(&mut self) { if thread::panicking() { - if self.zalsa_local.should_trigger_local_cancellation() { - self.release_cancelled(); - } else { - self.release_panicking(); - } + self.release_panicking(); return; } From 2f998273f28eb35a5329aef148a8f1c00135cf29 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Wed, 31 Dec 2025 16:41:34 +0100 Subject: [PATCH 7/7] Document attached states --- src/attach.rs | 17 +++++++++++++++-- src/function/memo.rs | 2 +- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/attach.rs b/src/attach.rs index 3a6edc93f..072b82238 100644 --- a/src/attach.rs +++ b/src/attach.rs @@ -40,6 +40,9 @@ impl Attached { Db: ?Sized + Database, { struct DbGuard<'s> { + /// The database that *we* attached on scope entry. + /// + /// `None` if one was already attached by a parent scope. state: Option<&'s Attached>, } @@ -47,6 +50,7 @@ impl Attached { #[inline] fn new(attached: &'s Attached, db: &dyn Database) -> Self { match attached.database.get() { + // A database is already attached, make sure it's the same as the new one. Some(current_db) => { let new_db = NonNull::from(db); if !std::ptr::addr_eq(current_db.as_ptr(), new_db.as_ptr()) { @@ -54,8 +58,8 @@ impl Attached { } Self { state: None } } + // No database is attached, attach the new one. None => { - // Otherwise, set the database. attached.database.set(Some(NonNull::from(db))); Self { state: Some(attached), @@ -88,7 +92,13 @@ impl Attached { Db: ?Sized + Database, { struct DbGuard<'s> { + /// The database that *we* attached on scope entry. + /// + /// `None` if one was already attached by a parent scope. state: Option<&'s Attached>, + /// The previously attached database that we replaced, if any. + /// + /// We need to make sure to rollback and activate it again when we exit the scope. prev: Option>, } @@ -97,21 +107,24 @@ impl Attached { fn new(attached: &'s Attached, db: &dyn Database) -> Self { let db = NonNull::from(db); match attached.database.replace(Some(db)) { + // A database was already attached by a parent scope. 
Some(prev) => { if std::ptr::eq(db.as_ptr(), prev.as_ptr()) { + // and it was the same as ours, so we did not change anything. Self { state: None, prev: None, } } else { + // and it was the a different one from ours, record the state changes. Self { state: Some(attached), prev: Some(prev), } } } + // No database is attached, attach the new one. None => { - // Otherwise, set the database. attached.database.set(Some(db)); Self { state: Some(attached), diff --git a/src/function/memo.rs b/src/function/memo.rs index b0f1af37d..9ebae6c9e 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -181,7 +181,7 @@ impl<'db, C: Configuration> Memo<'db, C> { TryClaimHeadsResult::Running(running) => { all_cycles = false; if !running.block_on(zalsa) { - // We cannot handle local cancellations in fixpoints + // We cannot really handle local cancellations reliably here // so we treat it as a general cancellation / panic. // // We shouldn't hit this though as we disable local cancellation
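
A minimal usage sketch of the API introduced by this series (not part of any patch above): `run_cancellable` and `expensive_query` are placeholder names, the timeout is arbitrary, and the pattern is distilled from `tests/cancellation_token.rs`.

    use std::{thread, time::Duration};
    use salsa::{Cancelled, Database};

    // Placeholder driver; `expensive_query` stands in for any `#[salsa::tracked]` fn.
    fn run_cancellable(db: &impl Database) {
        // One token per database handle; clones share the same flag.
        let token = db.cancellation_token();
        thread::scope(|s| {
            // Whoever holds the token may request cancellation from another thread.
            s.spawn(|| {
                thread::sleep(Duration::from_millis(100));
                token.cancel();
            });
            // Queries running on *this* handle unwind with `Cancelled::Local` at
            // their next cancellation check; `catch` turns the unwind into an `Err`.
            match Cancelled::catch(|| {
                // expensive_query(db)
            }) {
                Ok(()) => { /* finished before the cancellation request */ }
                Err(Cancelled::Local) => { /* cancelled via the token */ }
                Err(_) => { /* pending write or a propagated panic */ }
            }
        });
        // The cancellation flag is reset when the handle is detached, so later
        // queries on the same handle run normally (see `tests/cancellation_token.rs`).
    }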