From 9e8635c7db72403489554fe688df1218dfda7130 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sat, 27 Jul 2024 10:18:23 +0000 Subject: [PATCH 01/29] remove upcast_mut We only ever need to upcast to shared references. This change isn't necessary, just dead code cleanup. --- components/salsa-macros/src/db.rs | 2 +- src/storage.rs | 19 ++-------- src/views.rs | 58 ++----------------------------- 3 files changed, 6 insertions(+), 73 deletions(-) diff --git a/components/salsa-macros/src/db.rs b/components/salsa-macros/src/db.rs index f098524f..02fcef4b 100644 --- a/components/salsa-macros/src/db.rs +++ b/components/salsa-macros/src/db.rs @@ -123,7 +123,7 @@ impl DbMacro { #[doc(hidden)] fn zalsa_db(&self) { use salsa::plumbing as #zalsa; - #zalsa::views(self).add::(|t| t, |t| t); + #zalsa::views(self).add::(|t| t); } }); Ok(()) diff --git a/src/storage.rs b/src/storage.rs index f924069b..e28fc45c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -156,17 +156,6 @@ impl dyn Database { pub fn as_view(&self) -> &DbView { self.views().try_view_as(self).unwrap() } - - /// Upcasts `self` to the given view. - /// - /// # Panics - /// - /// If the view has not been added to the database (see [`DatabaseView`][]) - pub fn as_view_mut(&mut self) -> &mut DbView { - // Avoid a borrow check error by cloning. This is the "uncommon" path so it seems fine. - let upcasts = self.views().clone(); - upcasts.try_view_as_mut(self).unwrap() - } } /// Nonce type representing the underlying database storage. @@ -253,12 +242,8 @@ impl Default for Storage { impl Storage { /// Add an upcast function to type `T`. - pub fn add_upcast( - &mut self, - func: fn(&Db) -> &T, - func_mut: fn(&mut Db) -> &mut T, - ) { - self.upcasts.add::(func, func_mut) + pub fn add_upcast(&mut self, func: fn(&Db) -> &T) { + self.upcasts.add::(func) } /// Adds the ingredients in `jar` to the database if not already present. diff --git a/src/views.rs b/src/views.rs index 64601caa..5d0c7f47 100644 --- a/src/views.rs +++ b/src/views.rs @@ -24,7 +24,6 @@ struct ViewCaster { target_type_id: TypeId, type_name: &'static str, func: fn(&Dummy) -> &Dummy, - func_mut: fn(&mut Dummy) -> &mut Dummy, } #[allow(dead_code)] @@ -41,12 +40,8 @@ impl Default for ViewsOf { impl ViewsOf { /// Add a new upcast from `Db` to `T`, given the upcasting function `func`. - pub fn add( - &self, - func: fn(&Db) -> &DbView, - func_mut: fn(&mut Db) -> &mut DbView, - ) { - self.upcasts.add(func, func_mut); + pub fn add(&self, func: fn(&Db) -> &DbView) { + self.upcasts.add(func); } } @@ -68,11 +63,7 @@ impl Views { } /// Add a new upcast from `Db` to `T`, given the upcasting function `func`. - pub fn add( - &self, - func: fn(&Db) -> &DbView, - func_mut: fn(&mut Db) -> &mut DbView, - ) { + pub fn add(&self, func: fn(&Db) -> &DbView) { assert_eq!(self.source_type_id, TypeId::of::(), "dyn-upcasts"); let target_type_id = TypeId::of::(); @@ -89,11 +80,6 @@ impl Views { target_type_id, type_name: std::any::type_name::(), func: unsafe { std::mem::transmute:: &DbView, fn(&Dummy) -> &Dummy>(func) }, - func_mut: unsafe { - std::mem::transmute:: &mut DbView, fn(&mut Dummy) -> &mut Dummy>( - func_mut, - ) - }, }); } @@ -125,36 +111,6 @@ impl Views { None } - - /// Convert one handle to a salsa database (including a `dyn Database`!) to another. - /// - /// # Panics - /// - /// If the underlying type of `db` is not the same as the database type this upcasts was created for. - pub fn try_view_as_mut<'db, View: ?Sized + Any>( - &self, - db: &'db mut dyn Database, - ) -> Option<&'db mut View> { - let db_type_id = ::type_id(db); - assert_eq!(self.source_type_id, db_type_id, "database type mismatch"); - - let view_type_id = TypeId::of::(); - for caster in self.view_casters.iter() { - if caster.target_type_id == view_type_id { - // SAFETY: We have some function that takes a thin reference to the underlying - // database type `X` and returns a (potentially wide) reference to `View`. - // - // While the compiler doesn't know what `X` is at this point, we know it's the - // same as the true type of `db_data_ptr`, and the memory representation for `()` - // and `&X` are the same (since `X` is `Sized`). - let func_mut: fn(&mut ()) -> &mut View = - unsafe { std::mem::transmute(caster.func_mut) }; - return Some(func_mut(data_ptr_mut(db))); - } - } - - None - } } impl std::fmt::Debug for Views { @@ -179,14 +135,6 @@ fn data_ptr(t: &T) -> &() { unsafe { &*u } } -/// Given a wide pointer `T`, extracts the data pointer (typed as `()`). -/// This is safe because `()` gives no access to any data and has no validity requirements in particular. -fn data_ptr_mut(t: &mut T) -> &mut () { - let t: *mut T = t; - let u: *mut () = t as *mut (); - unsafe { &mut *u } -} - impl Clone for ViewsOf { fn clone(&self) -> Self { Self { From bc72bdf524393033694c5e72405d36680ae2e97e Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sat, 27 Jul 2024 10:32:42 +0000 Subject: [PATCH 02/29] as_salsa_database => as_dyn_database Also, move to a blanket impl'd trait. Overall cleaner approach. --- .../src/setup_accumulator_impl.rs | 2 +- .../src/setup_input_struct.rs | 10 ++--- .../src/setup_interned_struct.rs | 4 +- .../salsa-macro-rules/src/setup_tracked_fn.rs | 10 ++--- .../src/setup_tracked_struct.rs | 6 +-- src/database.rs | 22 ++++++++++- src/function.rs | 12 ++---- src/function/diff_outputs.rs | 4 +- src/function/fetch.rs | 8 ++-- src/function/maybe_changed_after.rs | 16 ++++---- src/function/specify.rs | 9 ++--- src/lib.rs | 1 + src/local_state.rs | 2 +- src/storage.rs | 39 ------------------- src/tracked_struct.rs | 2 +- 15 files changed, 61 insertions(+), 86 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_accumulator_impl.rs b/components/salsa-macro-rules/src/setup_accumulator_impl.rs index b474021e..cf36637c 100644 --- a/components/salsa-macro-rules/src/setup_accumulator_impl.rs +++ b/components/salsa-macro-rules/src/setup_accumulator_impl.rs @@ -35,7 +35,7 @@ macro_rules! setup_accumulator_impl { where Db: ?Sized + $zalsa::Database, { - let db = db.as_salsa_database(); + let db = db.as_dyn_database(); $ingredient(db).push(db, self); } } diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index c9e94df6..f320656c 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -128,7 +128,7 @@ macro_rules! setup_input_struct { { let current_revision = $zalsa::current_revision(db); let stamps = $zalsa::Array::new([$zalsa::stamp(current_revision, Default::default()); $N]); - $Configuration::ingredient(db.as_salsa_database()).new_input(($($field_id,)*), stamps) + $Configuration::ingredient(db.as_dyn_database()).new_input(($($field_id,)*), stamps) } $( @@ -137,8 +137,8 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let fields = $Configuration::ingredient(db.as_salsa_database()).field( - db.as_salsa_database(), + let fields = $Configuration::ingredient(db.as_dyn_database()).field( + db.as_dyn_database(), self, $field_index, ); @@ -157,7 +157,7 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let (ingredient, runtime) = $Configuration::ingredient_mut(db.as_salsa_database_mut()); + let (ingredient, runtime) = $Configuration::ingredient_mut(db.as_dyn_database_mut()); $zalsa::input::SetterImpl::new( runtime, self, @@ -174,7 +174,7 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database, { - $Configuration::ingredient(db.as_salsa_database()).get_singleton_input() + $Configuration::ingredient(db.as_dyn_database()).get_singleton_input() } #[track_caller] diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 049229f6..164a383d 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -81,7 +81,7 @@ macro_rules! setup_interned_struct { { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); - CACHE.get_or_create(db.as_salsa_database(), || { + CACHE.get_or_create(db.as_dyn_database(), || { db.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) }) } @@ -135,7 +135,7 @@ macro_rules! setup_interned_struct { $Db: ?Sized + salsa::Database, { let current_revision = $zalsa::current_revision(db); - $Configuration::ingredient(db).intern(db.as_salsa_database(), ($($field_id,)*)) + $Configuration::ingredient(db).intern(db.as_dyn_database(), ($($field_id,)*)) } $( diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index d91ba725..e72ec4ed 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -131,7 +131,7 @@ macro_rules! setup_tracked_fn { impl $Configuration { fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { - $FN_CACHE.get_or_create(db.as_salsa_database(), || { + $FN_CACHE.get_or_create(db.as_dyn_database(), || { ::zalsa_db(db); db.add_or_lookup_jar_by_type(&$Configuration) }) @@ -141,7 +141,7 @@ macro_rules! setup_tracked_fn { fn intern_ingredient( db: &dyn $Db, ) -> &$zalsa::interned::IngredientImpl<$Configuration> { - $INTERN_CACHE.get_or_create(db.as_salsa_database(), || { + $INTERN_CACHE.get_or_create(db.as_dyn_database(), || { db.add_or_lookup_jar_by_type(&$Configuration).successor(0) }) } @@ -193,7 +193,7 @@ macro_rules! setup_tracked_fn { if $needs_interner { $Configuration::intern_ingredient(db).data(key).clone() } else { - $zalsa::LookupId::lookup_id(key, db.as_salsa_database()) + $zalsa::LookupId::lookup_id(key, db.as_dyn_database()) } } } @@ -233,7 +233,7 @@ macro_rules! setup_tracked_fn { use salsa::plumbing as $zalsa; let key = $zalsa::macro_if! { if $needs_interner { - $Configuration::intern_ingredient($db).intern_id($db.as_salsa_database(), ($($input_id),*)) + $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*)) } else { $zalsa::AsId::as_id(&($($input_id),*)) } @@ -268,7 +268,7 @@ macro_rules! setup_tracked_fn { let result = $zalsa::macro_if! { if $needs_interner { { - let key = $Configuration::intern_ingredient($db).intern_id($db.as_salsa_database(), ($($input_id),*)); + let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*)); $Configuration::fn_ingredient($db).fetch($db, key) } } else { diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index d9f36063..3d566d69 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -192,8 +192,8 @@ macro_rules! setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - $Configuration::ingredient(db.as_salsa_database()).new_struct( - db.as_salsa_database(), + $Configuration::ingredient(db.as_dyn_database()).new_struct( + db.as_dyn_database(), ($($field_id,)*) ) } @@ -204,7 +204,7 @@ macro_rules! setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let fields = unsafe { self.0.as_ref() }.field(db.as_salsa_database(), $field_index); + let fields = unsafe { self.0.as_ref() }.field(db.as_dyn_database(), $field_index); $crate::maybe_clone!( $field_option, $field_ty, diff --git a/src/database.rs b/src/database.rs index fc1fe519..9aebe784 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,7 +1,7 @@ use crate::{local_state, storage::DatabaseGen, Durability, Event, Revision}; #[salsa_macros::db] -pub trait Database: DatabaseGen { +pub trait Database: DatabaseGen + AsDynDatabase { /// This function is invoked at key points in the salsa /// runtime. It permits the database to be customized and to /// inject logging or other custom behavior. @@ -31,7 +31,7 @@ pub trait Database: DatabaseGen { /// Queries which report untracked reads will be re-executed in the next /// revision. fn report_untracked_read(&self) { - let db = self.as_salsa_database(); + let db = self.as_dyn_database(); local_state::attach(db, |state| { state.report_untracked_read(db.runtime().current_revision()) }) @@ -46,6 +46,24 @@ pub trait Database: DatabaseGen { } } +/// Upcast to a `dyn Database`. +/// +/// Only required because upcasts not yet stabilized (*grr*). +pub trait AsDynDatabase { + fn as_dyn_database(&self) -> &dyn Database; + fn as_dyn_database_mut(&mut self) -> &mut dyn Database; +} + +impl AsDynDatabase for T { + fn as_dyn_database(&self) -> &dyn Database { + self + } + + fn as_dyn_database_mut(&mut self) -> &mut dyn Database { + self + } +} + pub fn current_revision(db: &Db) -> Revision { db.runtime().current_revision() } diff --git a/src/function.rs b/src/function.rs index 6cb1fd72..351fbb36 100644 --- a/src/function.rs +++ b/src/function.rs @@ -3,13 +3,9 @@ use std::{any::Any, fmt, sync::Arc}; use crossbeam::atomic::AtomicCell; use crate::{ - cycle::CycleRecoveryStrategy, - ingredient::fmt_index, - key::DatabaseKeyIndex, - local_state::QueryOrigin, - salsa_struct::SalsaStructInDb, - storage::{DatabaseGen, IngredientIndex}, - Cycle, Database, Event, EventKind, Id, Revision, + cycle::CycleRecoveryStrategy, ingredient::fmt_index, key::DatabaseKeyIndex, + local_state::QueryOrigin, salsa_struct::SalsaStructInDb, storage::IngredientIndex, + AsDynDatabase as _, Cycle, Database, Event, EventKind, Id, Revision, }; use self::delete::DeletedEntries; @@ -199,7 +195,7 @@ where fn register<'db>(&self, db: &'db C::DbView) { if !self.registered.fetch_or(true) { as SalsaStructInDb>::register_dependent_fn( - db.as_salsa_database(), + db.as_dyn_database(), self.index, ) } diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index 8cefa32c..28617040 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -1,5 +1,5 @@ use crate::{ - hash::FxHashSet, key::DependencyIndex, local_state::QueryRevisions, storage::DatabaseGen, + hash::FxHashSet, key::DependencyIndex, local_state::QueryRevisions, AsDynDatabase as _, Database, DatabaseKeyIndex, Event, EventKind, }; @@ -46,6 +46,6 @@ where }, }); - output.remove_stale_output(db.as_salsa_database(), key); + output.remove_stale_output(db.as_dyn_database(), key); } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index c5bbf7f1..09ac76e5 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -4,7 +4,7 @@ use crate::{ local_state::{self, LocalState}, runtime::StampedValue, storage::DatabaseGen, - Id, + AsDynDatabase as _, Id, }; use super::{Configuration, IngredientImpl}; @@ -14,8 +14,8 @@ where C: Configuration, { pub fn fetch<'db>(&'db self, db: &'db C::DbView, key: Id) -> &C::Output<'db> { - local_state::attach(db.as_salsa_database(), |local_state| { - local_state.unwind_if_revision_cancelled(db.as_salsa_database()); + local_state::attach(db.as_dyn_database(), |local_state| { + local_state.unwind_if_revision_cancelled(db.as_dyn_database()); let StampedValue { value, @@ -87,7 +87,7 @@ where // Try to claim this query: if someone else has claimed it already, go back and start again. let _claim_guard = self.sync_map - .claim(db.as_salsa_database(), local_state, database_key_index)?; + .claim(db.as_dyn_database(), local_state, database_key_index)?; // Push the query on the stack. let active_query = local_state.push_query(database_key_index); diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 41e57ae9..81b44481 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -5,7 +5,7 @@ use crate::{ local_state::{self, ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, runtime::StampedValue, storage::DatabaseGen, - Id, Revision, Runtime, + AsDynDatabase as _, Id, Revision, Runtime, }; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -20,9 +20,9 @@ where key: Id, revision: Revision, ) -> bool { - local_state::attach(db.as_salsa_database(), |local_state| { + local_state::attach(db.as_dyn_database(), |local_state| { let runtime = db.runtime(); - local_state.unwind_if_revision_cancelled(db.as_salsa_database()); + local_state.unwind_if_revision_cancelled(db.as_dyn_database()); loop { let database_key_index = self.database_key_index(key); @@ -63,7 +63,7 @@ where let _claim_guard = self.sync_map - .claim(db.as_salsa_database(), local_state, database_key_index)?; + .claim(db.as_dyn_database(), local_state, database_key_index)?; let active_query = local_state.push_query(database_key_index); // Load the current memo, if any. Use a real arc, not an arc-swap guard, @@ -118,7 +118,7 @@ where if memo.check_durability(runtime) { // No input of the suitable durability has changed since last verified. - let db = db.as_salsa_database(); + let db = db.as_dyn_database(); memo.mark_as_verified(db, runtime, database_key_index); memo.mark_outputs_as_verified(db, database_key_index); return true; @@ -185,7 +185,7 @@ where match edge_kind { EdgeKind::Input => { if dependency_index - .maybe_changed_after(db.as_salsa_database(), last_verified_at) + .maybe_changed_after(db.as_dyn_database(), last_verified_at) { return false; } @@ -208,14 +208,14 @@ where // so even if we mark them as valid here, the function will re-execute // and overwrite the contents. dependency_index - .mark_validated_output(db.as_salsa_database(), database_key_index); + .mark_validated_output(db.as_dyn_database(), database_key_index); } } } } } - old_memo.mark_as_verified(db.as_salsa_database(), runtime, database_key_index); + old_memo.mark_as_verified(db.as_dyn_database(), runtime, database_key_index); true } } diff --git a/src/function/specify.rs b/src/function/specify.rs index 97577d1e..82e2af3d 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -4,7 +4,7 @@ use crate::{ local_state::{self, QueryOrigin, QueryRevisions}, storage::DatabaseGen, tracked_struct::TrackedStructInDb, - Database, DatabaseKeyIndex, Id, + AsDynDatabase as _, Database, DatabaseKeyIndex, Id, }; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -19,7 +19,7 @@ where where C::Input<'db>: TrackedStructInDb, { - local_state::attach(db.as_salsa_database(), |state| { + local_state::attach(db.as_dyn_database(), |state| { let (active_query_key, current_deps) = match state.active_query() { Some(v) => v, None => panic!("can only use `specify` inside a tracked function"), @@ -37,8 +37,7 @@ where // * Q4 invokes Q2 and then Q1 // // Now, if We invoke Q3 first, We get one result for Q2, but if We invoke Q4 first, We get a different value. That's no good. - let database_key_index = - >::database_key_index(db.as_salsa_database(), key); + let database_key_index = >::database_key_index(db.as_dyn_database(), key); let dependency_index = database_key_index.into(); if !state.is_output_of_active_query(dependency_index) { panic!( @@ -120,6 +119,6 @@ where } let database_key_index = self.database_key_index(key); - memo.mark_as_verified(db.as_salsa_database(), runtime, database_key_index); + memo.mark_as_verified(db.as_dyn_database(), runtime, database_key_index); } } diff --git a/src/lib.rs b/src/lib.rs index 7d1397a1..b15641be 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,7 @@ mod views; pub use self::accumulator::Accumulator; pub use self::cancelled::Cancelled; pub use self::cycle::Cycle; +pub use self::database::AsDynDatabase; pub use self::database::Database; pub use self::durability::Durability; pub use self::event::Event; diff --git a/src/local_state.rs b/src/local_state.rs index 03473928..0bce6b51 100644 --- a/src/local_state.rs +++ b/src/local_state.rs @@ -30,7 +30,7 @@ pub(crate) fn attach(db: &DB, op: impl FnOnce(&LocalState) -> R) -> R where DB: ?Sized + Database, { - LOCAL_STATE.with(|state| state.attach(db.as_salsa_database(), || op(state))) + LOCAL_STATE.with(|state| state.attach(db.as_dyn_database(), || op(state))) } /// Access the "attached" database. Returns `None` if no database is attached. diff --git a/src/storage.rs b/src/storage.rs index e28fc45c..0ae6e194 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -23,33 +23,6 @@ pub fn views(db: &Db) -> &Views { /// This trait is meant to be implemented by our procedural macro. /// We need to document any non-obvious conditions that it satisfies. pub unsafe trait DatabaseGen: Any { - /// Upcast to a `dyn Database`. - /// - /// Only required because upcasts not yet stabilized (*grr*). - /// - /// # Safety - /// - /// Returns the same data pointer as `self`. - fn as_salsa_database(&self) -> &dyn Database; - - /// Upcast to a `dyn Database`. - /// - /// Only required because upcasts not yet stabilized (*grr*). - /// - /// # Safety - /// - /// Returns the same data pointer as `self`. - fn as_salsa_database_mut(&mut self) -> &mut dyn Database; - - /// Upcast to a `dyn DatabaseGen`. - /// - /// Only required because upcasts not yet stabilized (*grr*). - /// - /// # Safety - /// - /// Returns the same data pointer as `self`. - fn as_salsa_database_gen(&self) -> &dyn DatabaseGen; - /// Returns a reference to the underlying. fn views(&self) -> &Views; @@ -98,18 +71,6 @@ pub unsafe trait HasStorage: Database + Sized + Any { } unsafe impl DatabaseGen for T { - fn as_salsa_database(&self) -> &dyn Database { - self - } - - fn as_salsa_database_mut(&mut self) -> &mut dyn Database { - self - } - - fn as_salsa_database_gen(&self) -> &dyn DatabaseGen { - self - } - fn views(&self) -> &Views { &self.storage().upcasts } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 131a319c..08d8119c 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -471,7 +471,7 @@ where // `executor` creates a tracked struct `salsa_output_key`, // but it did not in the current revision. // In that case, we can delete `stale_output_key` and any data associated with it. - self.delete_entity(db.as_salsa_database(), stale_output_key.unwrap()); + self.delete_entity(db.as_dyn_database(), stale_output_key.unwrap()); } fn requires_reset_for_new_revision(&self) -> bool { From 596461c2135b55a4d2ec6b4154fcd5119386d84c Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sat, 27 Jul 2024 10:50:28 +0000 Subject: [PATCH 03/29] hide internal methods behind a Zalsa trait The traits are now quite simple: * Database is the external trait * ZalsaDatabase is the internal one, implemented by `#[salsa::db]`. It adds two methods, `zalsa` and `zalsa_mut`. Those give access to our internal methods. For now I've hidden the methods behind `&dyn Zalsa`. This is nice and clean but it may be worth later refactoring to a `struct Zalsa`. --- .../src/setup_accumulator_impl.rs | 2 +- .../src/setup_input_struct.rs | 7 +- .../src/setup_interned_struct.rs | 3 +- .../salsa-macro-rules/src/setup_tracked_fn.rs | 4 +- .../src/setup_tracked_struct.rs | 4 +- components/salsa-macros/src/db.rs | 10 +- src/accumulator.rs | 9 +- src/database.rs | 22 +- src/function.rs | 4 +- src/function/accumulated.rs | 9 +- src/function/execute.rs | 5 +- src/function/fetch.rs | 4 +- src/function/maybe_changed_after.rs | 6 +- src/function/specify.rs | 6 +- src/function/sync.rs | 2 +- src/handle.rs | 16 +- src/key.rs | 11 +- src/lib.rs | 3 +- src/local_state.rs | 2 +- src/runtime.rs | 1 + src/storage.rs | 189 +++++++----------- src/tracked_struct.rs | 14 +- src/tracked_struct/tracked_field.rs | 4 +- src/views.rs | 7 - 24 files changed, 159 insertions(+), 185 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_accumulator_impl.rs b/components/salsa-macro-rules/src/setup_accumulator_impl.rs index cf36637c..f10318a2 100644 --- a/components/salsa-macro-rules/src/setup_accumulator_impl.rs +++ b/components/salsa-macro-rules/src/setup_accumulator_impl.rs @@ -24,7 +24,7 @@ macro_rules! setup_accumulator_impl { fn $ingredient(db: &dyn $zalsa::Database) -> &$zalsa_struct::IngredientImpl<$Struct> { $CACHE.get_or_create(db, || { - db.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Struct>>::default()) + db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Struct>>::default()) }) } diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index f320656c..10257b6d 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -82,13 +82,14 @@ macro_rules! setup_input_struct { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); CACHE.get_or_create(db, || { - db.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) }) } pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { - let index = db.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()); - let (ingredient, runtime) = db.lookup_ingredient_mut(index); + let zalsa_mut = db.zalsa_mut(); + let index = zalsa_mut.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()); + let (ingredient, runtime) = zalsa_mut.lookup_ingredient_mut(index); let ingredient = ingredient.assert_type_mut::<$zalsa_struct::IngredientImpl>(); (ingredient, runtime) } diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 164a383d..1d2c868d 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -82,7 +82,7 @@ macro_rules! setup_interned_struct { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); CACHE.get_or_create(db.as_dyn_database(), || { - db.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) }) } } @@ -144,7 +144,6 @@ macro_rules! setup_interned_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let runtime = db.runtime(); let fields = $Configuration::ingredient(db).fields(self); $zalsa::maybe_clone!( $field_option, diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index e72ec4ed..e1165516 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -133,7 +133,7 @@ macro_rules! setup_tracked_fn { fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { $FN_CACHE.get_or_create(db.as_dyn_database(), || { ::zalsa_db(db); - db.add_or_lookup_jar_by_type(&$Configuration) + db.zalsa().add_or_lookup_jar_by_type(&$Configuration) }) } @@ -142,7 +142,7 @@ macro_rules! setup_tracked_fn { db: &dyn $Db, ) -> &$zalsa::interned::IngredientImpl<$Configuration> { $INTERN_CACHE.get_or_create(db.as_dyn_database(), || { - db.add_or_lookup_jar_by_type(&$Configuration).successor(0) + db.zalsa().add_or_lookup_jar_by_type(&$Configuration).successor(0) }) } } diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 3d566d69..21524f8a 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -134,14 +134,14 @@ macro_rules! setup_tracked_struct { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); CACHE.get_or_create(db, || { - db.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl::<$Configuration>>::default()) + db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl::<$Configuration>>::default()) }) } } impl<$db_lt> $zalsa::LookupId<$db_lt> for $Struct<$db_lt> { fn lookup_id(id: salsa::Id, db: &$db_lt dyn $zalsa::Database) -> Self { - $Configuration::ingredient(db).lookup_struct(db.runtime(), id) + $Configuration::ingredient(db).lookup_struct(db, id) } } diff --git a/components/salsa-macros/src/db.rs b/components/salsa-macros/src/db.rs index 02fcef4b..b7057613 100644 --- a/components/salsa-macros/src/db.rs +++ b/components/salsa-macros/src/db.rs @@ -33,7 +33,7 @@ impl DbMacro { fn try_db(self, input: syn::Item) -> syn::Result { match input { syn::Item::Struct(input) => { - let has_storage_impl = self.has_storage_impl(&input)?; + let has_storage_impl = self.zalsa_database_impl(&input)?; Ok(quote! { #has_storage_impl #input @@ -79,7 +79,7 @@ impl DbMacro { )) } - fn has_storage_impl(&self, input: &syn::ItemStruct) -> syn::Result { + fn zalsa_database_impl(&self, input: &syn::ItemStruct) -> syn::Result { let storage = self.find_storage_field(input)?; let db = &input.ident; let zalsa = self.hygiene.ident("zalsa"); @@ -88,12 +88,12 @@ impl DbMacro { const _: () = { use salsa::plumbing as #zalsa; - unsafe impl #zalsa::HasStorage for #db { - fn storage(&self) -> &#zalsa::Storage { + unsafe impl #zalsa::ZalsaDatabase for #db { + fn zalsa(&self) -> &dyn #zalsa::Zalsa { &self.#storage } - fn storage_mut(&mut self) -> &mut #zalsa::Storage { + fn zalsa_mut(&mut self) -> &mut dyn #zalsa::Zalsa { &mut self.#storage } } diff --git a/src/accumulator.rs b/src/accumulator.rs index a3ebc2ff..f00be0d8 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -59,8 +59,9 @@ impl IngredientImpl { Db: ?Sized + Database, { let jar: JarImpl = Default::default(); - let index = db.add_or_lookup_jar_by_type(&jar); - let ingredient = db.lookup_ingredient(index).assert_type::(); + let zalsa = db.zalsa(); + let index = zalsa.add_or_lookup_jar_by_type(&jar); + let ingredient = zalsa.lookup_ingredient(index).assert_type::(); Some(ingredient) } @@ -80,7 +81,7 @@ impl IngredientImpl { pub fn push(&self, db: &dyn crate::Database, value: A) { local_state::attach(db, |state| { - let runtime = db.runtime(); + let runtime = db.zalsa().runtime(); let current_revision = runtime.current_revision(); let (active_query, _) = match state.active_query() { Some(pair) => pair, @@ -162,7 +163,7 @@ impl Ingredient for IngredientImpl { output_key: Option, ) { assert!(output_key.is_none()); - let current_revision = db.runtime().current_revision(); + let current_revision = db.zalsa().runtime().current_revision(); if let Some(mut v) = self.map.get_mut(&executor) { // The value is still valid in the new revision. v.produced_at = current_revision; diff --git a/src/database.rs b/src/database.rs index 9aebe784..5d01801b 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,7 +1,7 @@ -use crate::{local_state, storage::DatabaseGen, Durability, Event, Revision}; +use crate::{local_state, storage::ZalsaDatabase, Durability, Event, Revision}; #[salsa_macros::db] -pub trait Database: DatabaseGen + AsDynDatabase { +pub trait Database: ZalsaDatabase + AsDynDatabase { /// This function is invoked at key points in the salsa /// runtime. It permits the database to be customized and to /// inject logging or other custom behavior. @@ -21,7 +21,7 @@ pub trait Database: DatabaseGen + AsDynDatabase { /// will block until that snapshot is dropped -- if that snapshot /// is owned by the current thread, this could trigger deadlock. fn synthetic_write(&mut self, durability: Durability) { - let runtime = self.runtime_mut(); + let runtime = self.zalsa_mut().runtime_mut(); runtime.new_revision(); runtime.report_tracked_write(durability); } @@ -33,7 +33,7 @@ pub trait Database: DatabaseGen + AsDynDatabase { fn report_untracked_read(&self) { let db = self.as_dyn_database(); local_state::attach(db, |state| { - state.report_untracked_read(db.runtime().current_revision()) + state.report_untracked_read(db.zalsa().runtime().current_revision()) }) } @@ -65,5 +65,17 @@ impl AsDynDatabase for T { } pub fn current_revision(db: &Db) -> Revision { - db.runtime().current_revision() + db.zalsa().runtime().current_revision() +} + +impl dyn Database { + /// Upcasts `self` to the given view. + /// + /// # Panics + /// + /// If the view has not been added to the database (see [`DatabaseView`][]) + #[track_caller] + pub fn as_view(&self) -> &DbView { + self.zalsa().views().try_view_as(self).unwrap() + } } diff --git a/src/function.rs b/src/function.rs index 351fbb36..3c95df59 100644 --- a/src/function.rs +++ b/src/function.rs @@ -271,8 +271,10 @@ where // Anything that was output by this memoized execution // is now itself stale. + let zalsa = db.zalsa(); for stale_output in origin.outputs() { - db.lookup_ingredient(stale_output.ingredient_index) + zalsa + .lookup_ingredient(stale_output.ingredient_index) .remove_stale_output(db, key, stale_output.key_index); } } diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 8ff145a3..e533fb28 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,5 +1,5 @@ use crate::{ - accumulator, hash::FxHashSet, local_state, storage::DatabaseGen, DatabaseKeyIndex, Id, + accumulator, hash::FxHashSet, local_state, storage::ZalsaDatabase as _, DatabaseKeyIndex, Id, }; use super::{Configuration, IngredientImpl}; @@ -15,7 +15,8 @@ where A: accumulator::Accumulator, { local_state::attach(db, |local_state| { - let current_revision = db.runtime().current_revision(); + let zalsa = db.zalsa(); + let current_revision = zalsa.runtime().current_revision(); let Some(accumulator) = >::from_db(db) else { return vec![]; @@ -33,7 +34,9 @@ where if visited.insert(k) { accumulator.produced_by(current_revision, local_state, k, &mut output); - let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); + let origin = zalsa + .lookup_ingredient(k.ingredient_index) + .origin(k.key_index); let inputs = origin.iter().flat_map(|origin| origin.inputs()); // Careful: we want to push in execution order, so reverse order to // ensure the first child that was executed will be the first child popped diff --git a/src/function/execute.rs b/src/function/execute.rs index d1e02fb9..f564c3ce 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{ - local_state::ActiveQueryGuard, runtime::StampedValue, storage::DatabaseGen, Cycle, Database, + local_state::ActiveQueryGuard, runtime::StampedValue, storage::ZalsaDatabase, Cycle, Database, Event, EventKind, }; @@ -26,7 +26,8 @@ where active_query: ActiveQueryGuard<'_>, opt_old_memo: Option>>>, ) -> StampedValue<&C::Output<'db>> { - let runtime = db.runtime(); + let zalsa = db.zalsa(); + let runtime = zalsa.runtime(); let revision_now = runtime.current_revision(); let database_key_index = active_query.database_key_index; diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 09ac76e5..5b25a70d 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -3,7 +3,7 @@ use arc_swap::Guard; use crate::{ local_state::{self, LocalState}, runtime::StampedValue, - storage::DatabaseGen, + storage::ZalsaDatabase as _, AsDynDatabase as _, Id, }; @@ -63,7 +63,7 @@ where let memo_guard = self.memo_map.get(key); if let Some(memo) = &memo_guard { if memo.value.is_some() { - let runtime = db.runtime(); + let runtime = db.zalsa().runtime(); if self.shallow_verify_memo(db, runtime, self.database_key_index(key), memo) { let value = unsafe { // Unsafety invariant: memo is present in memo_map diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 81b44481..80ec605a 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -4,7 +4,7 @@ use crate::{ key::DatabaseKeyIndex, local_state::{self, ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, runtime::StampedValue, - storage::DatabaseGen, + storage::ZalsaDatabase as _, AsDynDatabase as _, Id, Revision, Runtime, }; @@ -21,7 +21,7 @@ where revision: Revision, ) -> bool { local_state::attach(db.as_dyn_database(), |local_state| { - let runtime = db.runtime(); + let runtime = db.zalsa().runtime(); local_state.unwind_if_revision_cancelled(db.as_dyn_database()); loop { @@ -141,7 +141,7 @@ where old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, ) -> bool { - let runtime = db.runtime(); + let runtime = db.zalsa().runtime(); let database_key_index = active_query.database_key_index; tracing::debug!("{database_key_index:?}: deep_verify_memo(old_memo = {old_memo:#?})",); diff --git a/src/function/specify.rs b/src/function/specify.rs index 82e2af3d..9d2a3dd2 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -2,7 +2,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ local_state::{self, QueryOrigin, QueryRevisions}, - storage::DatabaseGen, + storage::ZalsaDatabase, tracked_struct::TrackedStructInDb, AsDynDatabase as _, Database, DatabaseKeyIndex, Id, }; @@ -64,7 +64,7 @@ where // - a result that is verified in the current revision, because it was set, which will use the set value // - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic) - let revision = db.runtime().current_revision(); + let revision = db.zalsa().runtime().current_revision(); let mut revisions = QueryRevisions { changed_at: current_deps.changed_at, durability: current_deps.durability, @@ -101,7 +101,7 @@ where executor: DatabaseKeyIndex, key: Id, ) { - let runtime = db.runtime(); + let runtime = db.zalsa().runtime(); let memo = match self.memo_map.get(key) { Some(m) => m, diff --git a/src/function/sync.rs b/src/function/sync.rs index f9b1d1fc..d59cc7e5 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -28,7 +28,7 @@ impl SyncMap { local_state: &LocalState, database_key_index: DatabaseKeyIndex, ) -> Option> { - let runtime = db.runtime(); + let runtime = db.zalsa().runtime(); let thread_id = std::thread::current().id(); match self.sync_map.entry(database_key_index.key_index) { dashmap::mapref::entry::Entry::Vacant(entry) => { diff --git a/src/handle.rs b/src/handle.rs index 9469e5e9..6b4dd167 100644 --- a/src/handle.rs +++ b/src/handle.rs @@ -2,13 +2,13 @@ use std::sync::Arc; use parking_lot::{Condvar, Mutex}; -use crate::{storage::HasStorage, Event, EventKind}; +use crate::{Database, Event, EventKind}; /// A database "handle" allows coordination of multiple async tasks accessing the same database. /// So long as you are just doing reads, you can freely clone. /// When you attempt to modify the database, you call `get_mut`, which will set the cancellation flag, /// causing other handles to get panics. Once all other handles are dropped, you can proceed. -pub struct Handle { +pub struct Handle { /// Reference to the database. This is always `Some` except during destruction. db: Option>, @@ -23,7 +23,7 @@ struct Coordinate { cvar: Condvar, } -impl Handle { +impl Handle { /// Create a new handle wrapping `db`. pub fn new(db: Db) -> Self { Self { @@ -77,8 +77,8 @@ impl Handle { /// This could deadlock if there is a single worker with two handles to the /// same database! fn cancel_others(&mut self) { - let storage = self.db().storage(); - storage.runtime().set_cancellation_flag(); + let zalsa = self.db().zalsa(); + zalsa.runtime().set_cancellation_flag(); self.db().salsa_event(Event { thread_id: std::thread::current().id(), @@ -94,7 +94,7 @@ impl Handle { // ANCHOR_END: cancel_other_workers } -impl Drop for Handle { +impl Drop for Handle { fn drop(&mut self) { // Drop the database handle *first* self.db.take(); @@ -105,7 +105,7 @@ impl Drop for Handle { } } -impl std::ops::Deref for Handle { +impl std::ops::Deref for Handle { type Target = Db; fn deref(&self) -> &Self::Target { @@ -113,7 +113,7 @@ impl std::ops::Deref for Handle { } } -impl Clone for Handle { +impl Clone for Handle { fn clone(&self) -> Self { *self.coordinate.clones.lock() += 1; diff --git a/src/key.rs b/src/key.rs index e3e43c07..b2b70292 100644 --- a/src/key.rs +++ b/src/key.rs @@ -32,7 +32,8 @@ impl DependencyIndex { } pub(crate) fn remove_stale_output(&self, db: &dyn Database, executor: DatabaseKeyIndex) { - db.lookup_ingredient(self.ingredient_index) + db.zalsa() + .lookup_ingredient(self.ingredient_index) .remove_stale_output(db, executor, self.key_index) } @@ -41,7 +42,8 @@ impl DependencyIndex { db: &dyn Database, database_key_index: DatabaseKeyIndex, ) { - db.lookup_ingredient(self.ingredient_index) + db.zalsa() + .lookup_ingredient(self.ingredient_index) .mark_validated_output(db, database_key_index, self.key_index) } @@ -50,7 +52,8 @@ impl DependencyIndex { db: &dyn Database, last_verified_at: crate::Revision, ) -> bool { - db.lookup_ingredient(self.ingredient_index) + db.zalsa() + .lookup_ingredient(self.ingredient_index) .maybe_changed_after(db, self.key_index, last_verified_at) } } @@ -58,7 +61,7 @@ impl DependencyIndex { impl std::fmt::Debug for DependencyIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { local_state::with_attached_database(|db| { - let ingredient = db.lookup_ingredient(self.ingredient_index); + let ingredient = db.zalsa().lookup_ingredient(self.ingredient_index); ingredient.fmt_index(self.key_index, f) }) .unwrap_or_else(|| { diff --git a/src/lib.rs b/src/lib.rs index b15641be..395938dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,10 +98,11 @@ pub mod plumbing { pub use crate::runtime::StampedValue; pub use crate::salsa_struct::SalsaStructInDb; pub use crate::storage::views; - pub use crate::storage::HasStorage; pub use crate::storage::IngredientCache; pub use crate::storage::IngredientIndex; pub use crate::storage::Storage; + pub use crate::storage::Zalsa; + pub use crate::storage::ZalsaDatabase; pub use crate::tracked_struct::TrackedStructInDb; pub use crate::update::always_update; pub use crate::update::helper::Dispatch as UpdateDispatch; diff --git a/src/local_state.rs b/src/local_state.rs index 0bce6b51..6e688e21 100644 --- a/src/local_state.rs +++ b/src/local_state.rs @@ -321,7 +321,7 @@ impl LocalState { /// `salsa_event` is emitted when this method is called, so that should be /// used instead. pub(crate) fn unwind_if_revision_cancelled(&self, db: &dyn Database) { - let runtime = db.runtime(); + let runtime = db.zalsa().runtime(); let thread_id = std::thread::current().id(); db.salsa_event(Event { thread_id, diff --git a/src/runtime.rs b/src/runtime.rs index 51c8561d..04b83645 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -310,6 +310,7 @@ impl Runtime { aqs.iter_mut() .skip_while(|aq| { match db + .zalsa() .lookup_ingredient(aq.database_key_index.ingredient_index) .cycle_recovery_strategy() { diff --git a/src/storage.rs b/src/storage.rs index 0ae6e194..3139cbd6 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -12,7 +12,7 @@ use crate::views::{Views, ViewsOf}; use crate::Database; pub fn views(db: &Db) -> &Views { - DatabaseGen::views(db) + db.zalsa().views() } /// Salsa database methods whose implementation is generated by @@ -22,7 +22,19 @@ pub fn views(db: &Db) -> &Views { /// /// This trait is meant to be implemented by our procedural macro. /// We need to document any non-obvious conditions that it satisfies. -pub unsafe trait DatabaseGen: Any { +pub unsafe trait ZalsaDatabase: Any { + /// Plumbing methods. + #[doc(hidden)] + fn zalsa(&self) -> &dyn Zalsa; + + #[doc(hidden)] + fn zalsa_mut(&mut self) -> &mut dyn Zalsa; +} + +/// The "plumbing interface" to the Salsa database. +/// +/// **NOT SEMVER STABLE.** +pub trait Zalsa { /// Returns a reference to the underlying. fn views(&self) -> &Views; @@ -56,66 +68,82 @@ pub unsafe trait DatabaseGen: Any { fn runtime_mut(&mut self) -> &mut Runtime; } -/// This is the *actual* trait that the macro generates. -/// It simply gives access to the internal storage. -/// Note that it is NOT a supertrait of `Database` -/// because it is not `dyn`-safe. -/// -/// # Safety -/// -/// The `storage` field must be an owned field of -/// the implementing struct. -pub unsafe trait HasStorage: Database + Sized + Any { - fn storage(&self) -> &Storage; - fn storage_mut(&mut self) -> &mut Storage; -} - -unsafe impl DatabaseGen for T { +impl Zalsa for Storage { fn views(&self) -> &Views { - &self.storage().upcasts + &self.upcasts } fn nonce(&self) -> Nonce { - self.storage().nonce + self.nonce } fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option { - self.storage().lookup_jar_by_type(jar) + self.jar_map.lock().get(&jar.type_id()).copied() } fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex { - self.storage().add_or_lookup_jar_by_type(jar) + { + let jar_type_id = jar.type_id(); + let mut jar_map = self.jar_map.lock(); + *jar_map + .entry(jar_type_id) + .or_insert_with(|| { + let index = IngredientIndex::from(self.ingredients_vec.len()); + let ingredients = jar.create_ingredients(index); + for ingredient in ingredients { + let expected_index = ingredient.ingredient_index(); + + if ingredient.requires_reset_for_new_revision() { + self.ingredients_requiring_reset.push(expected_index); + } + + let actual_index = self + .ingredients_vec + .push(ingredient); + assert_eq!( + expected_index.as_usize(), + actual_index, + "ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`", + self.ingredients_vec.get(actual_index).unwrap(), + expected_index, + actual_index, + ); + + } + index + }) + } } fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient { - self.storage().lookup_ingredient(index) + &**self.ingredients_vec.get(index.as_usize()).unwrap() } fn runtime(&self) -> &Runtime { - &self.storage().runtime + &self.runtime } fn runtime_mut(&mut self) -> &mut Runtime { - &mut self.storage_mut().runtime + &mut self.runtime } fn lookup_ingredient_mut( &mut self, index: IngredientIndex, ) -> (&mut dyn Ingredient, &mut Runtime) { - self.storage_mut().lookup_ingredient_mut(index) - } -} + self.runtime.new_revision(); -impl dyn Database { - /// Upcasts `self` to the given view. - /// - /// # Panics - /// - /// If the view has not been added to the database (see [`DatabaseView`][]) - #[track_caller] - pub fn as_view(&self) -> &DbView { - self.views().try_view_as(self).unwrap() + for index in self.ingredients_requiring_reset.iter() { + self.ingredients_vec + .get_mut(index.as_usize()) + .unwrap() + .reset_for_new_revision(); + } + + ( + &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap(), + &mut self.runtime, + ) } } @@ -145,7 +173,7 @@ impl IngredientIndex { } pub(crate) fn cycle_recovery_strategy(self, db: &dyn Database) -> CycleRecoveryStrategy { - db.lookup_ingredient(self).cycle_recovery_strategy() + db.zalsa().lookup_ingredient(self).cycle_recovery_strategy() } pub fn successor(self, index: usize) -> Self { @@ -154,7 +182,7 @@ impl IngredientIndex { /// Return the "debug name" of this ingredient (e.g., the name of the tracked struct it represents) pub(crate) fn debug_name(self, db: &dyn Database) -> &'static str { - db.lookup_ingredient(self).debug_name() + db.zalsa().lookup_ingredient(self).debug_name() } } @@ -201,79 +229,6 @@ impl Default for Storage { } // ANCHOR_END: default -impl Storage { - /// Add an upcast function to type `T`. - pub fn add_upcast(&mut self, func: fn(&Db) -> &T) { - self.upcasts.add::(func) - } - - /// Adds the ingredients in `jar` to the database if not already present. - /// If a jar of this type is already present, returns the index. - fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex { - let jar_type_id = jar.type_id(); - let mut jar_map = self.jar_map.lock(); - *jar_map - .entry(jar_type_id) - .or_insert_with(|| { - let index = IngredientIndex::from(self.ingredients_vec.len()); - let ingredients = jar.create_ingredients(index); - for ingredient in ingredients { - let expected_index = ingredient.ingredient_index(); - - if ingredient.requires_reset_for_new_revision() { - self.ingredients_requiring_reset.push(expected_index); - } - - let actual_index = self - .ingredients_vec - .push(ingredient); - assert_eq!( - expected_index.as_usize(), - actual_index, - "ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`", - self.ingredients_vec.get(actual_index).unwrap(), - expected_index, - actual_index, - ); - - } - index - }) - } - - /// Return the index of the 1st ingredient from the given jar. - pub fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option { - self.jar_map.lock().get(&jar.type_id()).copied() - } - - pub fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient { - &**self.ingredients_vec.get(index.as_usize()).unwrap() - } - - fn lookup_ingredient_mut( - &mut self, - index: IngredientIndex, - ) -> (&mut dyn Ingredient, &mut Runtime) { - self.runtime.new_revision(); - - for index in self.ingredients_requiring_reset.iter() { - self.ingredients_vec - .get_mut(index.as_usize()) - .unwrap() - .reset_for_new_revision(); - } - - ( - &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap(), - &mut self.runtime, - ) - } - - pub fn runtime(&self) -> &Runtime { - &self.runtime - } -} - /// Caches a pointer to an ingredient in a database. /// Optimized for the case of a single database. pub struct IngredientCache @@ -309,18 +264,18 @@ where /// If the ingredient is not already in the cache, it will be created. pub fn get_or_create<'s>( &self, - storage: &'s dyn Database, + db: &'s dyn Database, create_index: impl Fn() -> IngredientIndex, ) -> &'s I { let &(nonce, ingredient) = self.cached_data.get_or_init(|| { - let ingredient = self.create_ingredient(storage, &create_index); - (storage.nonce(), ingredient as *const I) + let ingredient = self.create_ingredient(db, &create_index); + (db.zalsa().nonce(), ingredient as *const I) }); - if storage.nonce() == nonce { + if db.zalsa().nonce() == nonce { unsafe { &*ingredient } } else { - self.create_ingredient(storage, &create_index) + self.create_ingredient(db, &create_index) } } @@ -330,6 +285,6 @@ where create_index: &impl Fn() -> IngredientIndex, ) -> &'s I { let index = create_index(); - storage.lookup_ingredient(index).assert_type::() + storage.zalsa().lookup_ingredient(index).assert_type::() } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 08d8119c..78bba1d8 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -12,7 +12,6 @@ use crate::{ ingredient_list::IngredientList, key::{DatabaseKeyIndex, DependencyIndex}, local_state::{self, QueryOrigin}, - runtime::Runtime, salsa_struct::SalsaStructInDb, storage::IngredientIndex, Database, Durability, Event, Id, Revision, @@ -292,6 +291,8 @@ where fields: C::Fields<'db>, ) -> C::Struct<'db> { local_state::attach(db, |local_state| { + let zalsa = db.zalsa(); + let data_hash = crate::hash::hash(&C::id_fields(&fields)); let (query_key, current_deps, disambiguator) = @@ -306,7 +307,7 @@ where let (id, new_id) = self.intern(entity_key); local_state.add_output(self.database_key_index(id).into()); - let current_revision = db.runtime().current_revision(); + let current_revision = zalsa.runtime().current_revision(); if new_id { // This is a new tracked struct, so create an entry in the struct map. @@ -377,8 +378,8 @@ where /// # Panics /// /// If the struct has not been created in this revision. - pub fn lookup_struct<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> { - let current_revision = runtime.current_revision(); + pub fn lookup_struct<'db>(&'db self, db: &'db dyn Database, id: Id) -> C::Struct<'db> { + let current_revision = db.zalsa().runtime().current_revision(); self.struct_map.get(current_revision, id) } @@ -405,7 +406,8 @@ where } for dependent_fn in self.dependent_fns.iter() { - db.lookup_ingredient(dependent_fn) + db.zalsa() + .lookup_ingredient(dependent_fn) .salsa_struct_deleted(db, id); } } @@ -456,7 +458,7 @@ where _executor: DatabaseKeyIndex, output_key: Option, ) { - let runtime = db.runtime(); + let runtime = db.zalsa().runtime(); let output_key = output_key.unwrap(); self.struct_map.validate(runtime, output_key); } diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 45e796ea..a10eb31f 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -48,7 +48,7 @@ where /// The caller is responible for selecting the appropriate element. pub fn field<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> { local_state::attach(db, |local_state| { - let current_revision = db.runtime().current_revision(); + let current_revision = db.zalsa().runtime().current_revision(); let data = self.struct_map.get(current_revision, id); let data = C::deref_struct(data); let changed_at = data.revisions[self.field_index]; @@ -85,7 +85,7 @@ where input: Option, revision: crate::Revision, ) -> bool { - let current_revision = db.runtime().current_revision(); + let current_revision = db.zalsa().runtime().current_revision(); let id = input.unwrap(); let data = self.struct_map.get(current_revision, id); let data = C::deref_struct(data); diff --git a/src/views.rs b/src/views.rs index 5d0c7f47..75369e75 100644 --- a/src/views.rs +++ b/src/views.rs @@ -38,13 +38,6 @@ impl Default for ViewsOf { } } -impl ViewsOf { - /// Add a new upcast from `Db` to `T`, given the upcasting function `func`. - pub fn add(&self, func: fn(&Db) -> &DbView) { - self.upcasts.add(func); - } -} - impl Deref for ViewsOf { type Target = Views; From 1842b1dfbb8d73ab1b2605906ca29322eadb1a92 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sat, 27 Jul 2024 10:53:15 +0000 Subject: [PATCH 04/29] (almost) encansulate Runtime into Zalsa The distinction is dumb and should go away. But we still need it for a bit. --- Cargo.toml | 1 + .../src/setup_input_struct.rs | 6 +- src/accumulator.rs | 5 +- src/active_query.rs | 7 +- src/database.rs | 10 +-- src/function/accumulated.rs | 2 +- src/function/execute.rs | 14 +--- src/function/fetch.rs | 4 +- src/function/maybe_changed_after.rs | 26 ++++--- src/function/memo.rs | 12 +-- src/function/specify.rs | 10 ++- src/function/sync.rs | 2 +- src/handle.rs | 2 +- src/input.rs | 6 +- src/input/setter.rs | 12 +-- src/local_state.rs | 18 +++-- src/runtime.rs | 18 +---- src/storage.rs | 75 ++++++++++++++----- src/tracked_struct.rs | 8 +- src/tracked_struct/struct_map.rs | 5 +- src/tracked_struct/tracked_field.rs | 4 +- 21 files changed, 136 insertions(+), 111 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 82dd0e6d..611aa69b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ rustc-hash = "2.0.0" salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" } salsa-macros = { path = "components/salsa-macros" } smallvec = "1.0.0" +lazy_static = "1.5.0" [dev-dependencies] annotate-snippets = "0.11.4" diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index 10257b6d..ccdc60c9 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -86,12 +86,12 @@ macro_rules! setup_input_struct { }) } - pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { + pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, $zalsa::Revision) { let zalsa_mut = db.zalsa_mut(); let index = zalsa_mut.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()); - let (ingredient, runtime) = zalsa_mut.lookup_ingredient_mut(index); + let (ingredient, current_revision) = zalsa_mut.lookup_ingredient_mut(index); let ingredient = ingredient.assert_type_mut::<$zalsa_struct::IngredientImpl>(); - (ingredient, runtime) + (ingredient, current_revision) } } diff --git a/src/accumulator.rs b/src/accumulator.rs index f00be0d8..1e9c368d 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -81,8 +81,7 @@ impl IngredientImpl { pub fn push(&self, db: &dyn crate::Database, value: A) { local_state::attach(db, |state| { - let runtime = db.zalsa().runtime(); - let current_revision = runtime.current_revision(); + let current_revision = db.zalsa().current_revision(); let (active_query, _) = match state.active_query() { Some(pair) => pair, None => { @@ -163,7 +162,7 @@ impl Ingredient for IngredientImpl { output_key: Option, ) { assert!(output_key.is_none()); - let current_revision = db.zalsa().runtime().current_revision(); + let current_revision = db.zalsa().current_revision(); if let Some(mut v) = self.map.get_mut(&executor) { // The value is still valid in the new revision. v.produced_at = current_revision; diff --git a/src/active_query.rs b/src/active_query.rs index 46ef8f02..8f575b75 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -2,8 +2,9 @@ use crate::{ durability::Durability, hash::{FxIndexMap, FxIndexSet}, key::{DatabaseKeyIndex, DependencyIndex}, + local_state::EMPTY_DEPENDENCIES, tracked_struct::Disambiguator, - Cycle, Revision, Runtime, + Cycle, Revision, }; use super::local_state::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions}; @@ -86,9 +87,9 @@ impl ActiveQuery { self.input_outputs.contains(&(EdgeKind::Output, key)) } - pub(crate) fn revisions(&self, runtime: &Runtime) -> QueryRevisions { + pub(crate) fn revisions(&self) -> QueryRevisions { let input_outputs = if self.input_outputs.is_empty() { - runtime.empty_dependencies() + EMPTY_DEPENDENCIES.clone() } else { self.input_outputs.iter().copied().collect() }; diff --git a/src/database.rs b/src/database.rs index 5d01801b..ba39cf66 100644 --- a/src/database.rs +++ b/src/database.rs @@ -21,9 +21,9 @@ pub trait Database: ZalsaDatabase + AsDynDatabase { /// will block until that snapshot is dropped -- if that snapshot /// is owned by the current thread, this could trigger deadlock. fn synthetic_write(&mut self, durability: Durability) { - let runtime = self.zalsa_mut().runtime_mut(); - runtime.new_revision(); - runtime.report_tracked_write(durability); + let zalsa_mut = self.zalsa_mut(); + zalsa_mut.new_revision(); + zalsa_mut.report_tracked_write(durability); } /// Reports that the query depends on some state unknown to salsa. @@ -33,7 +33,7 @@ pub trait Database: ZalsaDatabase + AsDynDatabase { fn report_untracked_read(&self) { let db = self.as_dyn_database(); local_state::attach(db, |state| { - state.report_untracked_read(db.zalsa().runtime().current_revision()) + state.report_untracked_read(db.zalsa().current_revision()) }) } @@ -65,7 +65,7 @@ impl AsDynDatabase for T { } pub fn current_revision(db: &Db) -> Revision { - db.zalsa().runtime().current_revision() + db.zalsa().current_revision() } impl dyn Database { diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index e533fb28..d1fc003d 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -16,7 +16,7 @@ where { local_state::attach(db, |local_state| { let zalsa = db.zalsa(); - let current_revision = zalsa.runtime().current_revision(); + let current_revision = zalsa.current_revision(); let Some(accumulator) = >::from_db(db) else { return vec![]; diff --git a/src/function/execute.rs b/src/function/execute.rs index f564c3ce..7cbbf62d 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -27,8 +27,7 @@ where opt_old_memo: Option>>>, ) -> StampedValue<&C::Output<'db>> { let zalsa = db.zalsa(); - let runtime = zalsa.runtime(); - let revision_now = runtime.current_revision(); + let revision_now = zalsa.current_revision(); let database_key_index = active_query.database_key_index; tracing::info!("{:?}: executing query", database_key_index); @@ -68,16 +67,7 @@ where } } }; - let mut revisions = active_query.pop(runtime); - - // We assume that query is side-effect free -- that is, does - // not mutate the "inputs" to the query system. Sanity check - // that assumption here, at least to the best of our ability. - assert_eq!( - runtime.current_revision(), - revision_now, - "revision altered during query execution", - ); + let mut revisions = active_query.pop(); // If the new value is equal to the old one, then it didn't // really change, even if some of its inputs have. So we can diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 5b25a70d..2474acbe 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -63,8 +63,8 @@ where let memo_guard = self.memo_map.get(key); if let Some(memo) = &memo_guard { if memo.value.is_some() { - let runtime = db.zalsa().runtime(); - if self.shallow_verify_memo(db, runtime, self.database_key_index(key), memo) { + let zalsa = db.zalsa(); + if self.shallow_verify_memo(db, zalsa, self.database_key_index(key), memo) { let value = unsafe { // Unsafety invariant: memo is present in memo_map self.extend_memo_lifetime(memo).unwrap() diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 80ec605a..a9426cf7 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -4,8 +4,8 @@ use crate::{ key::DatabaseKeyIndex, local_state::{self, ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, runtime::StampedValue, - storage::ZalsaDatabase as _, - AsDynDatabase as _, Id, Revision, Runtime, + storage::{Zalsa, ZalsaDatabase as _}, + AsDynDatabase as _, Id, Revision, }; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -21,7 +21,7 @@ where revision: Revision, ) -> bool { local_state::attach(db.as_dyn_database(), |local_state| { - let runtime = db.zalsa().runtime(); + let zalsa = db.zalsa(); local_state.unwind_if_revision_cancelled(db.as_dyn_database()); loop { @@ -34,7 +34,7 @@ where // Check if we have a verified version: this is the hot path. let memo_guard = self.memo_map.get(key); if let Some(memo) = &memo_guard { - if self.shallow_verify_memo(db, runtime, database_key_index, memo) { + if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { return memo.revisions.changed_at > revision; } drop(memo_guard); // release the arc-swap guard before cold path @@ -102,12 +102,12 @@ where pub(super) fn shallow_verify_memo( &self, db: &C::DbView, - runtime: &Runtime, + zalsa: &dyn Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo>, ) -> bool { let verified_at = memo.verified_at.load(); - let revision_now = runtime.current_revision(); + let revision_now = zalsa.current_revision(); tracing::debug!("{database_key_index:?}: shallow_verify_memo(memo = {memo:#?})",); @@ -116,10 +116,10 @@ where return true; } - if memo.check_durability(runtime) { + if memo.check_durability(zalsa) { // No input of the suitable durability has changed since last verified. let db = db.as_dyn_database(); - memo.mark_as_verified(db, runtime, database_key_index); + memo.mark_as_verified(db, revision_now, database_key_index); memo.mark_outputs_as_verified(db, database_key_index); return true; } @@ -141,12 +141,12 @@ where old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, ) -> bool { - let runtime = db.zalsa().runtime(); + let zalsa = db.zalsa(); let database_key_index = active_query.database_key_index; tracing::debug!("{database_key_index:?}: deep_verify_memo(old_memo = {old_memo:#?})",); - if self.shallow_verify_memo(db, runtime, database_key_index, old_memo) { + if self.shallow_verify_memo(db, zalsa, database_key_index, old_memo) { return true; } @@ -215,7 +215,11 @@ where } } - old_memo.mark_as_verified(db.as_dyn_database(), runtime, database_key_index); + old_memo.mark_as_verified( + db.as_dyn_database(), + zalsa.current_revision(), + database_key_index, + ); true } } diff --git a/src/function/memo.rs b/src/function/memo.rs index 074fcf1a..ef26a014 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -4,8 +4,8 @@ use arc_swap::{ArcSwap, Guard}; use crossbeam::atomic::AtomicCell; use crate::{ - hash::FxDashMap, key::DatabaseKeyIndex, local_state::QueryRevisions, Event, EventKind, Id, - Revision, Runtime, + hash::FxDashMap, key::DatabaseKeyIndex, local_state::QueryRevisions, storage::Zalsa, Event, + EventKind, Id, Revision, }; use super::Configuration; @@ -129,8 +129,8 @@ impl Memo { } } /// True if this memo is known not to have changed based on its durability. - pub(super) fn check_durability(&self, runtime: &Runtime) -> bool { - let last_changed = runtime.last_changed_revision(self.revisions.durability); + pub(super) fn check_durability(&self, zalsa: &dyn Zalsa) -> bool { + let last_changed = zalsa.last_changed_revision(self.revisions.durability); let verified_at = self.verified_at.load(); tracing::debug!( "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}", @@ -146,7 +146,7 @@ impl Memo { pub(super) fn mark_as_verified( &self, db: &dyn crate::Database, - runtime: &crate::Runtime, + revision_now: Revision, database_key_index: DatabaseKeyIndex, ) { db.salsa_event(Event { @@ -156,7 +156,7 @@ impl Memo { }, }); - self.verified_at.store(runtime.current_revision()); + self.verified_at.store(revision_now); } pub(super) fn mark_outputs_as_verified( diff --git a/src/function/specify.rs b/src/function/specify.rs index 9d2a3dd2..37e88082 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -64,7 +64,7 @@ where // - a result that is verified in the current revision, because it was set, which will use the set value // - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic) - let revision = db.zalsa().runtime().current_revision(); + let revision = db.zalsa().current_revision(); let mut revisions = QueryRevisions { changed_at: current_deps.changed_at, durability: current_deps.durability, @@ -101,7 +101,7 @@ where executor: DatabaseKeyIndex, key: Id, ) { - let runtime = db.zalsa().runtime(); + let zalsa = db.zalsa(); let memo = match self.memo_map.get(key) { Some(m) => m, @@ -119,6 +119,10 @@ where } let database_key_index = self.database_key_index(key); - memo.mark_as_verified(db.as_dyn_database(), runtime, database_key_index); + memo.mark_as_verified( + db.as_dyn_database(), + zalsa.current_revision(), + database_key_index, + ); } } diff --git a/src/function/sync.rs b/src/function/sync.rs index d59cc7e5..0f1d5178 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -28,7 +28,7 @@ impl SyncMap { local_state: &LocalState, database_key_index: DatabaseKeyIndex, ) -> Option> { - let runtime = db.zalsa().runtime(); + let runtime = db.zalsa().runtimex(); let thread_id = std::thread::current().id(); match self.sync_map.entry(database_key_index.key_index) { dashmap::mapref::entry::Entry::Vacant(entry) => { diff --git a/src/handle.rs b/src/handle.rs index 6b4dd167..28661888 100644 --- a/src/handle.rs +++ b/src/handle.rs @@ -78,7 +78,7 @@ impl Handle { /// same database! fn cancel_others(&mut self) { let zalsa = self.db().zalsa(); - zalsa.runtime().set_cancellation_flag(); + zalsa.set_cancellation_flag(); self.db().salsa_event(Event { thread_id: std::thread::current().id(), diff --git a/src/input.rs b/src/input.rs index ba1bd41a..6fa683e6 100644 --- a/src/input.rs +++ b/src/input.rs @@ -19,7 +19,6 @@ use crate::{ key::{DatabaseKeyIndex, DependencyIndex}, local_state::{self, QueryOrigin}, plumbing::{Jar, Stamp}, - runtime::Runtime, storage::IngredientIndex, Database, Durability, Id, Revision, }; @@ -121,18 +120,17 @@ impl IngredientImpl { /// * `setter`, function that modifies the fields tuple; should only modify the element for `field_index` pub fn set_field( &mut self, - runtime: &mut Runtime, + current_revision: Revision, id: C::Struct, field_index: usize, durability: Durability, setter: impl FnOnce(&mut C::Fields) -> R, ) -> R { - let revision = runtime.current_revision(); let id: Id = id.as_id(); let mut r = self.struct_map.update(id); let stamp = &mut r.stamps[field_index]; stamp.durability = durability; - stamp.changed_at = revision; + stamp.changed_at = current_revision; setter(&mut r.fields) } diff --git a/src/input/setter.rs b/src/input/setter.rs index a976aad0..e19d9c95 100644 --- a/src/input/setter.rs +++ b/src/input/setter.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use crate::input::{Configuration, IngredientImpl}; -use crate::{Durability, Runtime}; +use crate::{Durability, Revision}; /// Setter for a field of an input. pub trait Setter: Sized { @@ -12,7 +12,7 @@ pub trait Setter: Sized { #[must_use] pub struct SetterImpl<'setter, C: Configuration, S, F> { - runtime: &'setter mut Runtime, + current_revision: Revision, id: C::Struct, ingredient: &'setter mut IngredientImpl, durability: Durability, @@ -27,14 +27,14 @@ where S: FnOnce(&mut C::Fields, F) -> F, { pub fn new( - runtime: &'setter mut Runtime, + current_revision: Revision, id: C::Struct, field_index: usize, ingredient: &'setter mut IngredientImpl, setter: S, ) -> Self { SetterImpl { - runtime, + current_revision, id, field_index, ingredient, @@ -59,7 +59,7 @@ where fn to(self, value: F) -> F { let Self { - runtime, + current_revision, id, ingredient, durability, @@ -68,7 +68,7 @@ where phantom: _, } = self; - ingredient.set_field(runtime, id, field_index, durability, |tuple| { + ingredient.set_field(current_revision, id, field_index, durability, |tuple| { setter(tuple, value) }) } diff --git a/src/local_state.rs b/src/local_state.rs index 6e688e21..cbad4d73 100644 --- a/src/local_state.rs +++ b/src/local_state.rs @@ -13,7 +13,6 @@ use crate::Database; use crate::Event; use crate::EventKind; use crate::Revision; -use crate::Runtime; use std::cell::Cell; use std::cell::RefCell; use std::ptr::NonNull; @@ -321,21 +320,20 @@ impl LocalState { /// `salsa_event` is emitted when this method is called, so that should be /// used instead. pub(crate) fn unwind_if_revision_cancelled(&self, db: &dyn Database) { - let runtime = db.zalsa().runtime(); let thread_id = std::thread::current().id(); db.salsa_event(Event { thread_id, kind: EventKind::WillCheckCancellation, }); - if runtime.load_cancellation_flag() { - self.unwind_cancelled(runtime); + let zalsa = db.zalsa(); + if zalsa.load_cancellation_flag() { + self.unwind_cancelled(zalsa.current_revision()); } } #[cold] - pub(crate) fn unwind_cancelled(&self, runtime: &Runtime) { - let current_revision = runtime.current_revision(); + pub(crate) fn unwind_cancelled(&self, current_revision: Revision) { self.report_untracked_read(current_revision); Cancelled::PendingWrite.throw(); } @@ -414,6 +412,10 @@ pub enum EdgeKind { Output, } +lazy_static::lazy_static! { + pub(crate) static ref EMPTY_DEPENDENCIES: Arc<[(EdgeKind, DependencyIndex)]> = Arc::new([]); +} + /// The edges between a memoized value and other queries in the dependency graph. /// These edges include both dependency edges /// e.g., when creating the memoized value for Q0 executed another function Q1) @@ -497,14 +499,14 @@ impl ActiveQueryGuard<'_> { /// which summarizes the other queries that were accessed during this /// query's execution. #[inline] - pub(crate) fn pop(self, runtime: &Runtime) -> QueryRevisions { + pub(crate) fn pop(self) -> QueryRevisions { // Extract accumulated inputs. let popped_query = self.complete(); // If this frame were a cycle participant, it would have unwound. assert!(popped_query.cycle.is_none()); - popped_query.revisions(runtime) + popped_query.revisions() } /// If the active query is registered as a cycle participant, remove and diff --git a/src/runtime.rs b/src/runtime.rs index 04b83645..3db4a6a6 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -8,13 +8,9 @@ use crossbeam::atomic::AtomicCell; use parking_lot::Mutex; use crate::{ - active_query::ActiveQuery, - cycle::CycleRecoveryStrategy, - durability::Durability, - key::{DatabaseKeyIndex, DependencyIndex}, - local_state::{EdgeKind, LocalState}, - revision::AtomicRevision, - Cancelled, Cycle, Database, Event, EventKind, Revision, + active_query::ActiveQuery, cycle::CycleRecoveryStrategy, durability::Durability, + key::DatabaseKeyIndex, local_state::LocalState, revision::AtomicRevision, Cancelled, Cycle, + Database, Event, EventKind, Revision, }; use self::dependency_graph::DependencyGraph; @@ -25,9 +21,6 @@ pub struct Runtime { /// Stores the next id to use for a snapshotted runtime (starts at 1). next_id: AtomicUsize, - /// Vector we can clone - empty_dependencies: Arc<[(EdgeKind, DependencyIndex)]>, - /// Set to true when the current revision has been canceled. /// This is done when we an input is being changed. The flag /// is set back to false once the input has been changed. @@ -89,7 +82,6 @@ impl Default for Runtime { .map(|_| AtomicRevision::start()) .collect(), next_id: AtomicUsize::new(1), - empty_dependencies: None.into_iter().collect(), revision_canceled: Default::default(), dependency_graph: Default::default(), } @@ -112,10 +104,6 @@ impl Runtime { self.revisions[0].load() } - pub(crate) fn empty_dependencies(&self) -> Arc<[(EdgeKind, DependencyIndex)]> { - self.empty_dependencies.clone() - } - /// Reports that an input with durability `durability` changed. /// This will update the 'last changed at' values for every durability /// less than or equal to `durability` to the current revision. diff --git a/src/storage.rs b/src/storage.rs index 3139cbd6..0ecf132e 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -9,7 +9,7 @@ use crate::ingredient::{Ingredient, Jar}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::Runtime; use crate::views::{Views, ViewsOf}; -use crate::Database; +use crate::{Database, Durability, Revision}; pub fn views(db: &Db) -> &Views { db.zalsa().views() @@ -55,17 +55,36 @@ pub trait Zalsa { /// Gets an `&`-ref to an ingredient by index fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient; - /// Gets an `&mut`-ref to an ingredient by index; also returns the runtime for further use + /// Gets an `&mut`-ref to an ingredient by index. + /// + /// **Triggers a new revision.** Returns the `&mut` reference + /// along with the new revision index. fn lookup_ingredient_mut( &mut self, index: IngredientIndex, - ) -> (&mut dyn Ingredient, &mut Runtime); + ) -> (&mut dyn Ingredient, Revision); - /// Gets the salsa runtime - fn runtime(&self) -> &Runtime; + fn runtimex(&self) -> &Runtime; - /// Gets the salsa runtime - fn runtime_mut(&mut self) -> &mut Runtime; + /// Return the current revision + fn current_revision(&self) -> Revision; + + /// Increment revision counter. + /// + /// **Triggers a new revision.** + fn new_revision(&mut self) -> Revision; + + /// Return the time when an input of durability `durability` last changed + fn last_changed_revision(&self, durability: Durability) -> Revision; + + /// True if any threads have signalled for cancellation + fn load_cancellation_flag(&self) -> bool; + + /// Signal for cancellation, indicating current thread is trying to get unique access. + fn set_cancellation_flag(&self); + + /// Reports a (synthetic) tracked write to "some input of the given durability". + fn report_tracked_write(&mut self, durability: Durability); } impl Zalsa for Storage { @@ -119,19 +138,11 @@ impl Zalsa for Storage { &**self.ingredients_vec.get(index.as_usize()).unwrap() } - fn runtime(&self) -> &Runtime { - &self.runtime - } - - fn runtime_mut(&mut self) -> &mut Runtime { - &mut self.runtime - } - fn lookup_ingredient_mut( &mut self, index: IngredientIndex, - ) -> (&mut dyn Ingredient, &mut Runtime) { - self.runtime.new_revision(); + ) -> (&mut dyn Ingredient, Revision) { + let new_revision = self.runtime.new_revision(); for index in self.ingredients_requiring_reset.iter() { self.ingredients_vec @@ -142,9 +153,37 @@ impl Zalsa for Storage { ( &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap(), - &mut self.runtime, + new_revision, ) } + + fn current_revision(&self) -> Revision { + self.runtime.current_revision() + } + + fn load_cancellation_flag(&self) -> bool { + self.runtime.load_cancellation_flag() + } + + fn report_tracked_write(&mut self, durability: Durability) { + self.runtime.report_tracked_write(durability) + } + + fn runtimex(&self) -> &Runtime { + &self.runtime + } + + fn last_changed_revision(&self, durability: Durability) -> Revision { + self.runtime.last_changed_revision(durability) + } + + fn set_cancellation_flag(&self) { + self.runtime.set_cancellation_flag() + } + + fn new_revision(&mut self) -> Revision { + self.runtime.new_revision() + } } /// Nonce type representing the underlying database storage. diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 78bba1d8..0621782b 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -307,7 +307,7 @@ where let (id, new_id) = self.intern(entity_key); local_state.add_output(self.database_key_index(id).into()); - let current_revision = zalsa.runtime().current_revision(); + let current_revision = zalsa.current_revision(); if new_id { // This is a new tracked struct, so create an entry in the struct map. @@ -379,7 +379,7 @@ where /// /// If the struct has not been created in this revision. pub fn lookup_struct<'db>(&'db self, db: &'db dyn Database, id: Id) -> C::Struct<'db> { - let current_revision = db.zalsa().runtime().current_revision(); + let current_revision = db.zalsa().current_revision(); self.struct_map.get(current_revision, id) } @@ -458,9 +458,9 @@ where _executor: DatabaseKeyIndex, output_key: Option, ) { - let runtime = db.zalsa().runtime(); + let current_revision = db.zalsa().current_revision(); let output_key = output_key.unwrap(); - self.struct_map.validate(runtime, output_key); + self.struct_map.validate(current_revision, output_key); } fn remove_stale_output( diff --git a/src/tracked_struct/struct_map.rs b/src/tracked_struct/struct_map.rs index 4779a278..b8ea1578 100644 --- a/src/tracked_struct/struct_map.rs +++ b/src/tracked_struct/struct_map.rs @@ -6,7 +6,7 @@ use std::{ use crossbeam::queue::SegQueue; use dashmap::mapref::one::RefMut; -use crate::{alloc::Alloc, hash::FxDashMap, Id, Revision, Runtime}; +use crate::{alloc::Alloc, hash::FxDashMap, Id, Revision}; use super::{Configuration, KeyStruct, Value}; @@ -99,7 +99,7 @@ where unsafe { C::struct_from_raw(pointer) } } - pub fn validate<'db>(&'db self, runtime: &'db Runtime, id: Id) { + pub fn validate<'db>(&'db self, current_revision: Revision, id: Id) { let mut data = self.map.get_mut(&id).unwrap(); // UNSAFE: We never permit `&`-access in the current revision until data.created_at @@ -107,7 +107,6 @@ where let data = unsafe { data.as_mut() }; // Never update a struct twice in the same revision. - let current_revision = runtime.current_revision(); assert!(data.created_at < current_revision); data.created_at = current_revision; } diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index a10eb31f..7a6b7b42 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -48,7 +48,7 @@ where /// The caller is responible for selecting the appropriate element. pub fn field<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> { local_state::attach(db, |local_state| { - let current_revision = db.zalsa().runtime().current_revision(); + let current_revision = db.zalsa().current_revision(); let data = self.struct_map.get(current_revision, id); let data = C::deref_struct(data); let changed_at = data.revisions[self.field_index]; @@ -85,7 +85,7 @@ where input: Option, revision: crate::Revision, ) -> bool { - let current_revision = db.zalsa().runtime().current_revision(); + let current_revision = db.zalsa().current_revision(); let id = input.unwrap(); let data = self.struct_map.get(current_revision, id); let data = C::deref_struct(data); From 64556e9d286c26aed57b67af18f988bc632255c6 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sat, 27 Jul 2024 11:47:58 +0000 Subject: [PATCH 05/29] make event generation lazy Creating events if nobody is listening has always bugged me. --- examples/calc/db.rs | 3 ++- examples/lazy-input/main.rs | 3 ++- src/accumulator.rs | 2 +- src/database.rs | 8 ++++++-- src/function.rs | 2 +- src/function/diff_outputs.rs | 8 +++++--- src/function/execute.rs | 2 +- src/function/memo.rs | 2 +- src/handle.rs | 2 +- src/local_state.rs | 2 +- src/runtime.rs | 2 +- src/tracked_struct.rs | 2 +- tests/accumulate-from-tracked-fn.rs | 2 +- tests/accumulate-reuse-workaround.rs | 2 +- tests/accumulate-reuse.rs | 2 +- tests/accumulate.rs | 2 +- tests/deletion-cascade.rs | 3 ++- tests/deletion.rs | 3 ++- tests/parallel/setup.rs | 3 ++- tests/preverify-struct-with-leaked-data.rs | 3 ++- tests/tracked-struct-value-field-bad-eq.rs | 3 ++- 21 files changed, 37 insertions(+), 24 deletions(-) diff --git a/examples/calc/db.rs b/examples/calc/db.rs index 348adceb..35d47341 100644 --- a/examples/calc/db.rs +++ b/examples/calc/db.rs @@ -36,7 +36,8 @@ impl Database { // ANCHOR: db_impl #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, event: salsa::Event) { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + let event = event(); eprintln!("Event: {event:?}"); // Log interesting events, if logging is enabled if let Some(logs) = &self.logs { diff --git a/examples/lazy-input/main.rs b/examples/lazy-input/main.rs index e37908b2..3c0a6b4b 100644 --- a/examples/lazy-input/main.rs +++ b/examples/lazy-input/main.rs @@ -124,8 +124,9 @@ impl Db for Database { #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, event: salsa::Event) { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { // don't log boring events + let event = event(); if let salsa::EventKind::WillExecute { .. } = event.kind { self.logs.lock().unwrap().push(format!("{:?}", event)); } diff --git a/src/accumulator.rs b/src/accumulator.rs index 1e9c368d..dc114bb2 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -177,7 +177,7 @@ impl Ingredient for IngredientImpl { ) { assert!(stale_output_key.is_none()); if self.map.remove(&executor).is_some() { - db.salsa_event(Event { + db.salsa_event(&|| Event { thread_id: std::thread::current().id(), kind: EventKind::DidDiscardAccumulated { executor_key: executor, diff --git a/src/database.rs b/src/database.rs index ba39cf66..23606cfa 100644 --- a/src/database.rs +++ b/src/database.rs @@ -8,8 +8,12 @@ pub trait Database: ZalsaDatabase + AsDynDatabase { /// /// By default, the event is logged at level debug using /// the standard `log` facade. - fn salsa_event(&self, event: Event) { - tracing::debug!("salsa_event: {:?}", event) + /// + /// # Parameters + /// + /// * `event`, a fn that, if called, will create the event that occurred + fn salsa_event(&self, event: &dyn Fn() -> Event) { + tracing::debug!("salsa_event: {:?}", event()) } /// A "synthetic write" causes the system to act *as though* some diff --git a/src/function.rs b/src/function.rs index 3c95df59..72313cab 100644 --- a/src/function.rs +++ b/src/function.rs @@ -264,7 +264,7 @@ where if let Some(origin) = self.delete_memo(id) { let key = self.database_key_index(id); - db.salsa_event(Event { + db.salsa_event(&|| Event { thread_id: std::thread::current().id(), kind: EventKind::DidDiscard { key }, }); diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index 28617040..96df56f7 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -1,6 +1,6 @@ use crate::{ hash::FxHashSet, key::DependencyIndex, local_state::QueryRevisions, AsDynDatabase as _, - Database, DatabaseKeyIndex, Event, EventKind, + DatabaseKeyIndex, Event, EventKind, }; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -38,7 +38,9 @@ where } fn report_stale_output(db: &C::DbView, key: DatabaseKeyIndex, output: DependencyIndex) { - db.salsa_event(Event { + let db = db.as_dyn_database(); + + db.salsa_event(&|| Event { thread_id: std::thread::current().id(), kind: EventKind::WillDiscardStaleOutput { execute_key: key, @@ -46,6 +48,6 @@ where }, }); - output.remove_stale_output(db.as_dyn_database(), key); + output.remove_stale_output(db, key); } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 7cbbf62d..9ed1699e 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -32,7 +32,7 @@ where tracing::info!("{:?}: executing query", database_key_index); - db.salsa_event(Event { + db.salsa_event(&|| Event { thread_id: std::thread::current().id(), kind: EventKind::WillExecute { database_key: database_key_index, diff --git a/src/function/memo.rs b/src/function/memo.rs index ef26a014..4413d71a 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -149,7 +149,7 @@ impl Memo { revision_now: Revision, database_key_index: DatabaseKeyIndex, ) { - db.salsa_event(Event { + db.salsa_event(&|| Event { thread_id: std::thread::current().id(), kind: EventKind::DidValidateMemoizedValue { database_key: database_key_index, diff --git a/src/handle.rs b/src/handle.rs index 28661888..e3c2ecd9 100644 --- a/src/handle.rs +++ b/src/handle.rs @@ -80,7 +80,7 @@ impl Handle { let zalsa = self.db().zalsa(); zalsa.set_cancellation_flag(); - self.db().salsa_event(Event { + self.db().salsa_event(&|| Event { thread_id: std::thread::current().id(), kind: EventKind::DidSetCancellationFlag, diff --git a/src/local_state.rs b/src/local_state.rs index cbad4d73..b9a16c4c 100644 --- a/src/local_state.rs +++ b/src/local_state.rs @@ -321,7 +321,7 @@ impl LocalState { /// used instead. pub(crate) fn unwind_if_revision_cancelled(&self, db: &dyn Database) { let thread_id = std::thread::current().id(); - db.salsa_event(Event { + db.salsa_event(&|| Event { thread_id, kind: EventKind::WillCheckCancellation, diff --git a/src/runtime.rs b/src/runtime.rs index 3db4a6a6..3ac4ae8e 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -187,7 +187,7 @@ impl Runtime { assert!(!dg.depends_on(other_id, thread_id)); } - db.salsa_event(Event { + db.salsa_event(&|| Event { thread_id, kind: EventKind::WillBlockOn { other_thread_id: other_id, diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 0621782b..45136529 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -394,7 +394,7 @@ where /// unspecified results (but not UB). See [`InternedIngredient::delete_index`] for more /// discussion and important considerations. pub(crate) fn delete_entity(&self, db: &dyn crate::Database, id: Id) { - db.salsa_event(Event { + db.salsa_event(&|| Event { thread_id: std::thread::current().id(), kind: crate::EventKind::DidDiscard { key: self.database_key_index(id), diff --git a/tests/accumulate-from-tracked-fn.rs b/tests/accumulate-from-tracked-fn.rs index a60f08b8..039b1a9b 100644 --- a/tests/accumulate-from-tracked-fn.rs +++ b/tests/accumulate-from-tracked-fn.rs @@ -52,7 +52,7 @@ struct Database { #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, _event: salsa::Event) {} + fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} } #[salsa::db] diff --git a/tests/accumulate-reuse-workaround.rs b/tests/accumulate-reuse-workaround.rs index 8ba48f30..783ca0cc 100644 --- a/tests/accumulate-reuse-workaround.rs +++ b/tests/accumulate-reuse-workaround.rs @@ -59,7 +59,7 @@ struct Database { #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, _event: salsa::Event) {} + fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} } #[salsa::db] diff --git a/tests/accumulate-reuse.rs b/tests/accumulate-reuse.rs index 37bfacaa..2075a8e7 100644 --- a/tests/accumulate-reuse.rs +++ b/tests/accumulate-reuse.rs @@ -50,7 +50,7 @@ struct Database { #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, _event: salsa::Event) {} + fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} } #[salsa::db] diff --git a/tests/accumulate.rs b/tests/accumulate.rs index 3b6ce192..099128bd 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -65,7 +65,7 @@ struct Database { #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, _event: salsa::Event) {} + fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} } #[salsa::db] diff --git a/tests/deletion-cascade.rs b/tests/deletion-cascade.rs index 1d65602f..8385d048 100644 --- a/tests/deletion-cascade.rs +++ b/tests/deletion-cascade.rs @@ -60,7 +60,8 @@ struct Database { #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, event: salsa::Event) { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + let event = event(); match event.kind { salsa::EventKind::WillDiscardStaleOutput { .. } | salsa::EventKind::DidDiscard { .. } => { diff --git a/tests/deletion.rs b/tests/deletion.rs index 78c0a58d..cdcc14bf 100644 --- a/tests/deletion.rs +++ b/tests/deletion.rs @@ -54,7 +54,8 @@ struct Database { #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, event: salsa::Event) { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + let event = event(); match event.kind { salsa::EventKind::WillDiscardStaleOutput { .. } | salsa::EventKind::DidDiscard { .. } => { diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 94c31e50..c7c00ee8 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -37,7 +37,8 @@ pub(crate) struct Database { #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, event: salsa::Event) { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + let event = event(); match event.kind { salsa::EventKind::WillBlockOn { .. } => { self.signal(self.knobs().signal_on_will_block.load()); diff --git a/tests/preverify-struct-with-leaked-data.rs b/tests/preverify-struct-with-leaked-data.rs index 0f52ea0f..2e311c6f 100644 --- a/tests/preverify-struct-with-leaked-data.rs +++ b/tests/preverify-struct-with-leaked-data.rs @@ -25,7 +25,8 @@ struct Database { #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, event: salsa::Event) { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + let event = event(); self.push_log(format!("{event:?}")); } } diff --git a/tests/tracked-struct-value-field-bad-eq.rs b/tests/tracked-struct-value-field-bad-eq.rs index 3759c912..a9d50c6f 100644 --- a/tests/tracked-struct-value-field-bad-eq.rs +++ b/tests/tracked-struct-value-field-bad-eq.rs @@ -61,7 +61,8 @@ struct Database { #[salsa::db] impl salsa::Database for Database { - fn salsa_event(&self, event: salsa::Event) { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + let event = event(); match event.kind { salsa::EventKind::WillExecute { .. } | salsa::EventKind::DidValidateMemoizedValue { .. } => { From daaa78056abafafc7fcde91c633cbc3a8ae35f48 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sat, 27 Jul 2024 12:29:41 +0000 Subject: [PATCH 06/29] switch to new database design Under this design, *all* databases are a `DatabaseImpl`, where the `U` implements `UserData` (you can use `()` if there is none). Code would default to `&dyn salsa::Database` but if you want to give access to the userdata, you can define a custom database trait `MyDatabase: salsa::Databse` so long as you * annotate `MyDatabase` trait definition of impls of `MyDatabase` with `#[salsa::db]` * implement `MyDatabase` for `DatabaseImpl` where `U` is your userdata (this could be a blanket impl, if you don't know the precise userdata type). The `tests/common/mod.rs` shows the pattern. --- benches/incremental.rs | 2 +- components/salsa-macros/src/db.rs | 52 +-------- examples/calc/db.rs | 43 ++++--- examples/calc/main.rs | 3 +- examples/calc/parser.rs | 6 +- examples/calc/type_check.rs | 9 +- examples/lazy-input/main.rs | 35 +++--- src/database.rs | 109 ++++++++++++++++-- src/function/accumulated.rs | 4 +- src/function/execute.rs | 3 +- src/function/fetch.rs | 3 +- src/function/maybe_changed_after.rs | 4 +- src/function/specify.rs | 1 - src/lib.rs | 20 +--- src/storage.rs | 46 ++++---- tests/accumulate-chain.rs | 4 +- tests/accumulate-custom-clone.rs | 2 +- tests/accumulate-custom-debug.rs | 2 +- tests/accumulate-dag.rs | 2 +- tests/accumulate-execution-order.rs | 2 +- tests/accumulate-from-tracked-fn.rs | 31 +---- tests/accumulate-no-duplicates.rs | 2 +- tests/accumulate-reuse-workaround.rs | 36 +----- tests/accumulate-reuse.rs | 32 +---- tests/accumulate.rs | 40 ++----- tests/common/mod.rs | 96 ++++++++++++++- .../get-set-on-private-input-field.rs | 2 +- ...of-tracked-structs-from-older-revisions.rs | 2 +- tests/compile-fail/span-input-setter.rs | 2 +- tests/compile-fail/span-tracked-getter.rs | 2 +- tests/compile-fail/span-tracked-getter.stderr | 2 +- tests/cycles.rs | 41 +++---- tests/debug.rs | 17 +-- tests/deletion-cascade.rs | 47 ++------ tests/deletion-drops.rs | 2 +- tests/deletion.rs | 43 +------ tests/elided-lifetime-in-tracked-fn.rs | 32 +---- ...truct_changes_but_fn_depends_on_field_y.rs | 34 +----- ...input_changes_but_fn_depends_on_field_y.rs | 31 +---- tests/hello_world.rs | 32 +---- tests/interned-struct-with-lifetime.rs | 30 +---- tests/is_send_sync.rs | 20 +--- tests/lru.rs | 40 ++----- tests/mutate_in_place.rs | 21 +--- tests/override_new_get_set.rs | 14 +-- ...ng-tracked-struct-outside-of-tracked-fn.rs | 11 +- tests/parallel/parallel_cancellation.rs | 15 +-- tests/parallel/parallel_cycle_all_recover.rs | 27 ++--- tests/parallel/parallel_cycle_mid_recover.rs | 27 ++--- tests/parallel/parallel_cycle_none_recover.rs | 15 +-- tests/parallel/parallel_cycle_one_recover.rs | 23 ++-- tests/parallel/setup.rs | 35 +++--- tests/preverify-struct-with-leaked-data.rs | 35 +----- tests/singleton.rs | 24 +--- ...the-key-is-created-in-the-current-query.rs | 26 ++--- tests/synthetic_write.rs | 42 +------ tests/tracked-struct-id-field-bad-eq.rs | 16 +-- tests/tracked-struct-id-field-bad-hash.rs | 11 +- tests/tracked-struct-unchanged-in-new-rev.rs | 11 +- tests/tracked-struct-value-field-bad-eq.rs | 39 +------ tests/tracked-struct-value-field-not-eq.rs | 15 +-- tests/tracked_fn_constant.rs | 11 +- tests/tracked_fn_no_eq.rs | 33 +----- tests/tracked_fn_on_input.rs | 11 +- tests/tracked_fn_on_interned.rs | 11 +- tests/tracked_fn_on_tracked.rs | 11 +- tests/tracked_fn_on_tracked_specify.rs | 13 +-- tests/tracked_fn_read_own_entity.rs | 32 +---- tests/tracked_fn_read_own_specify.rs | 32 +---- tests/tracked_fn_return_ref.rs | 13 +-- tests/tracked_method.rs | 11 +- tests/tracked_method_inherent_return_ref.rs | 13 +-- tests/tracked_method_on_tracked_struct.rs | 4 +- tests/tracked_method_trait_return_ref.rs | 13 +-- tests/tracked_struct_db1_lt.rs | 9 -- tests/tracked_with_intern.rs | 9 -- tests/tracked_with_struct_db.rs | 15 +-- 77 files changed, 491 insertions(+), 1125 deletions(-) diff --git a/benches/incremental.rs b/benches/incremental.rs index 101e6d01..5e5aa5f4 100644 --- a/benches/incremental.rs +++ b/benches/incremental.rs @@ -26,7 +26,7 @@ fn many_tracked_structs(criterion: &mut Criterion) { criterion.bench_function("many_tracked_structs", |b| { b.iter_batched_ref( || { - let db = salsa::default_database(); + let db = salsa::DatabaseImpl::new(); let input = Input::new(&db, 1_000); let input2 = Input::new(&db, 1); diff --git a/components/salsa-macros/src/db.rs b/components/salsa-macros/src/db.rs index b7057613..e70ad745 100644 --- a/components/salsa-macros/src/db.rs +++ b/components/salsa-macros/src/db.rs @@ -32,13 +32,6 @@ struct DbMacro { impl DbMacro { fn try_db(self, input: syn::Item) -> syn::Result { match input { - syn::Item::Struct(input) => { - let has_storage_impl = self.zalsa_database_impl(&input)?; - Ok(quote! { - #has_storage_impl - #input - }) - } syn::Item::Trait(mut input) => { self.add_salsa_view_method(&mut input)?; Ok(quote! { @@ -53,54 +46,11 @@ impl DbMacro { } _ => Err(syn::Error::new_spanned( input, - "`db` must be applied to a struct, trait, or impl", + "`db` must be applied to a trait or impl", )), } } - fn find_storage_field(&self, input: &syn::ItemStruct) -> syn::Result { - let storage = "storage"; - for field in input.fields.iter() { - if let Some(i) = &field.ident { - if i == storage { - return Ok(i.clone()); - } - } else { - return Err(syn::Error::new_spanned( - field, - "database struct must be a braced struct (`{}`) with a field named `storage`", - )); - } - } - - Err(syn::Error::new_spanned( - &input.ident, - "database struct must be a braced struct (`{}`) with a field named `storage`", - )) - } - - fn zalsa_database_impl(&self, input: &syn::ItemStruct) -> syn::Result { - let storage = self.find_storage_field(input)?; - let db = &input.ident; - let zalsa = self.hygiene.ident("zalsa"); - - Ok(quote! { - const _: () = { - use salsa::plumbing as #zalsa; - - unsafe impl #zalsa::ZalsaDatabase for #db { - fn zalsa(&self) -> &dyn #zalsa::Zalsa { - &self.#storage - } - - fn zalsa_mut(&mut self) -> &mut dyn #zalsa::Zalsa { - &mut self.#storage - } - } - }; - }) - } - fn add_salsa_view_method(&self, input: &mut syn::ItemTrait) -> syn::Result<()> { input.items.push(parse_quote! { #[doc(hidden)] diff --git a/examples/calc/db.rs b/examples/calc/db.rs index 35d47341..5f1d98f7 100644 --- a/examples/calc/db.rs +++ b/examples/calc/db.rs @@ -1,49 +1,48 @@ use std::sync::{Arc, Mutex}; +use salsa::UserData; + +pub type CalcDatabaseImpl = salsa::DatabaseImpl; + // ANCHOR: db_struct #[derive(Default)] -#[salsa::db] -pub(crate) struct Database { - storage: salsa::Storage, - +pub struct Calc { // The logs are only used for testing and demonstrating reuse: - // - logs: Option>>>, + logs: Arc>>>, } // ANCHOR_END: db_struct -impl Database { +impl Calc { /// Enable logging of each salsa event. #[cfg(test)] - pub fn enable_logging(self) -> Self { - assert!(self.logs.is_none()); - Self { - storage: self.storage, - logs: Some(Default::default()), + pub fn enable_logging(&self) { + let mut logs = self.logs.lock().unwrap(); + if logs.is_none() { + *logs = Some(vec![]); } } #[cfg(test)] - pub fn take_logs(&mut self) -> Vec { - if let Some(logs) = &self.logs { - std::mem::take(&mut *logs.lock().unwrap()) + pub fn take_logs(&self) -> Vec { + let mut logs = self.logs.lock().unwrap(); + if let Some(logs) = &mut *logs { + std::mem::take(logs) } else { - panic!("logs not enabled"); + vec![] } } } // ANCHOR: db_impl -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { +impl UserData for Calc { + fn salsa_event(db: &CalcDatabaseImpl, event: &dyn Fn() -> salsa::Event) { let event = event(); eprintln!("Event: {event:?}"); // Log interesting events, if logging is enabled - if let Some(logs) = &self.logs { - // don't log boring events + if let Some(logs) = &mut *db.logs.lock().unwrap() { + // only log interesting events if let salsa::EventKind::WillExecute { .. } = event.kind { - logs.lock().unwrap().push(format!("Event: {event:?}")); + logs.push(format!("Event: {event:?}")); } } } diff --git a/examples/calc/main.rs b/examples/calc/main.rs index 332b20e2..616dede6 100644 --- a/examples/calc/main.rs +++ b/examples/calc/main.rs @@ -1,3 +1,4 @@ +use db::CalcDatabaseImpl; use ir::{Diagnostic, SourceProgram}; use salsa::Database as Db; @@ -8,7 +9,7 @@ mod parser; mod type_check; pub fn main() { - let db = db::Database::default(); + let db: CalcDatabaseImpl = Default::default(); let source_program = SourceProgram::new(&db, String::new()); compile::compile(&db, source_program); let diagnostics = compile::compile::accumulated::(&db, source_program); diff --git a/examples/calc/parser.rs b/examples/calc/parser.rs index 05c1c4b3..40077361 100644 --- a/examples/calc/parser.rs +++ b/examples/calc/parser.rs @@ -351,9 +351,11 @@ impl<'db> Parser<'_, 'db> { /// Returns the statements and the diagnostics generated. #[cfg(test)] fn parse_string(source_text: &str) -> String { - use salsa::Database as _; + use salsa::Database; - crate::db::Database::default().attach(|db| { + use crate::db::CalcDatabaseImpl; + + CalcDatabaseImpl::default().attach(|db| { // Create the source program let source_program = SourceProgram::new(db, source_text.to_string()); diff --git a/examples/calc/type_check.rs b/examples/calc/type_check.rs index 4fe22605..de449ad3 100644 --- a/examples/calc/type_check.rs +++ b/examples/calc/type_check.rs @@ -6,8 +6,6 @@ use derive_new::new; use expect_test::expect; use salsa::Accumulator; #[cfg(test)] -use salsa::Database as _; -#[cfg(test)] use test_log::test; // ANCHOR: parse_statements @@ -100,12 +98,13 @@ fn check_string( expected_diagnostics: expect_test::Expect, edits: &[(&str, expect_test::Expect, expect_test::Expect)], ) { - use salsa::Setter; + use salsa::{Database, Setter}; - use crate::{db::Database, ir::SourceProgram, parser::parse_statements}; + use crate::{db::CalcDatabaseImpl, ir::SourceProgram, parser::parse_statements}; // Create the database - let mut db = Database::default().enable_logging(); + let mut db = CalcDatabaseImpl::default(); + db.enable_logging(); // Create the source program let source_program = SourceProgram::new(&db, source_text.to_string()); diff --git a/examples/lazy-input/main.rs b/examples/lazy-input/main.rs index 3c0a6b4b..3891367b 100644 --- a/examples/lazy-input/main.rs +++ b/examples/lazy-input/main.rs @@ -8,13 +8,13 @@ use notify_debouncer_mini::{ notify::{RecommendedWatcher, RecursiveMode}, DebounceEventResult, Debouncer, }; -use salsa::{Accumulator, Setter}; +use salsa::{Accumulator, DatabaseImpl, Setter, UserData}; // ANCHOR: main fn main() -> Result<()> { // Create the channel to receive file change events. let (tx, rx) = unbounded(); - let mut db = Database::new(tx); + let mut db = DatabaseImpl::with(LazyInput::new(tx)); let initial_file_path = std::env::args_os() .nth(1) @@ -74,19 +74,15 @@ trait Db: salsa::Database { fn input(&self, path: PathBuf) -> Result; } -#[salsa::db] -struct Database { - storage: salsa::Storage, +struct LazyInput { logs: Mutex>, files: DashMap, file_watcher: Mutex>, } -impl Database { +impl LazyInput { fn new(tx: Sender) -> Self { - let storage = Default::default(); Self { - storage, logs: Default::default(), files: DashMap::new(), file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), tx).unwrap()), @@ -94,8 +90,18 @@ impl Database { } } +impl UserData for LazyInput { + fn salsa_event(db: &DatabaseImpl, event: &dyn Fn() -> salsa::Event) { + // don't log boring events + let event = event(); + if let salsa::EventKind::WillExecute { .. } = event.kind { + db.logs.lock().unwrap().push(format!("{:?}", event)); + } + } +} + #[salsa::db] -impl Db for Database { +impl Db for DatabaseImpl { fn input(&self, path: PathBuf) -> Result { let path = path .canonicalize() @@ -122,17 +128,6 @@ impl Db for Database { } // ANCHOR_END: db -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - // don't log boring events - let event = event(); - if let salsa::EventKind::WillExecute { .. } = event.kind { - self.logs.lock().unwrap().push(format!("{:?}", event)); - } - } -} - #[salsa::accumulator] struct Diagnostic(String); diff --git a/src/database.rs b/src/database.rs index 23606cfa..52f74e65 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,20 +1,24 @@ -use crate::{local_state, storage::ZalsaDatabase, Durability, Event, Revision}; +use std::{any::Any, panic::RefUnwindSafe}; +use crate::{self as salsa, local_state, storage::Zalsa, Durability, Event, Revision, Storage}; + +/// The trait implemented by all Salsa databases. +/// You can create your own subtraits of this trait using the `#[salsa::db]` procedural macro. +/// +/// # Safety conditions +/// +/// This trait can only safely be implemented by Salsa's [`DatabaseImpl`][] type. +/// FIXME: Document better the unsafety conditions we guarantee. #[salsa_macros::db] -pub trait Database: ZalsaDatabase + AsDynDatabase { - /// This function is invoked at key points in the salsa - /// runtime. It permits the database to be customized and to - /// inject logging or other custom behavior. - /// - /// By default, the event is logged at level debug using - /// the standard `log` facade. +pub unsafe trait Database: AsDynDatabase + Any { + /// This function is invoked by the salsa runtime at various points during execution. + /// You can customize what happens by implementing the [`UserData`][] trait. + /// By default, the event is logged at level debug using tracing facade. /// /// # Parameters /// /// * `event`, a fn that, if called, will create the event that occurred - fn salsa_event(&self, event: &dyn Fn() -> Event) { - tracing::debug!("salsa_event: {:?}", event()) - } + fn salsa_event(&self, event: &dyn Fn() -> Event); /// A "synthetic write" causes the system to act *as though* some /// input of durability `durability` has changed. This is mostly @@ -48,6 +52,13 @@ pub trait Database: ZalsaDatabase + AsDynDatabase { { local_state::attach(self, |_state| op(self)) } + + /// Plumbing methods. + #[doc(hidden)] + fn zalsa(&self) -> &dyn Zalsa; + + #[doc(hidden)] + fn zalsa_mut(&mut self) -> &mut dyn Zalsa; } /// Upcast to a `dyn Database`. @@ -83,3 +94,79 @@ impl dyn Database { self.zalsa().views().try_view_as(self).unwrap() } } + +/// Concrete implementation of the [`Database`][] trait. +/// Takes an optional type parameter `U` that allows you to thread your own data. +pub struct DatabaseImpl { + storage: Storage, +} + +impl Default for DatabaseImpl { + fn default() -> Self { + Self::with(U::default()) + } +} + +impl DatabaseImpl<()> { + /// Create a new database with the given user data. + /// + /// You can also use the [`Default`][] trait if your userdata implements it. + pub fn new() -> Self { + Self { + storage: Storage::with(()), + } + } +} + +impl DatabaseImpl { + /// Create a new database with the given user data. + /// + /// You can also use the [`Default`][] trait if your userdata implements it. + pub fn with(u: U) -> Self { + Self { + storage: Storage::with(u), + } + } +} + +impl std::ops::Deref for DatabaseImpl { + type Target = U; + + fn deref(&self) -> &U { + &self.storage.user_data() + } +} + +impl RefUnwindSafe for DatabaseImpl {} + +#[salsa_macros::db] +unsafe impl Database for DatabaseImpl { + fn zalsa(&self) -> &dyn Zalsa { + &self.storage + } + + fn zalsa_mut(&mut self) -> &mut dyn Zalsa { + &mut self.storage + } + + // Report a salsa event. + fn salsa_event(&self, event: &dyn Fn() -> Event) { + U::salsa_event(self, event) + } +} + +pub trait UserData: Any + Sized { + /// Callback invoked by the [`Database`][] at key points during salsa execution. + /// By overriding this method, you can inject logging or other custom behavior. + /// + /// By default, the event is logged at level debug using the `tracing` crate. + /// + /// # Parameters + /// + /// * `event` a fn that, if called, will return the event that occurred + fn salsa_event(_db: &DatabaseImpl, event: &dyn Fn() -> Event) { + tracing::debug!("salsa_event: {:?}", event()) + } +} + +impl UserData for () {} diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index d1fc003d..98d11eaa 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,6 +1,4 @@ -use crate::{ - accumulator, hash::FxHashSet, local_state, storage::ZalsaDatabase as _, DatabaseKeyIndex, Id, -}; +use crate::{accumulator, hash::FxHashSet, local_state, Database, DatabaseKeyIndex, Id}; use super::{Configuration, IngredientImpl}; diff --git a/src/function/execute.rs b/src/function/execute.rs index 9ed1699e..cba601d0 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,8 +1,7 @@ use std::sync::Arc; use crate::{ - local_state::ActiveQueryGuard, runtime::StampedValue, storage::ZalsaDatabase, Cycle, Database, - Event, EventKind, + local_state::ActiveQueryGuard, runtime::StampedValue, Cycle, Database, Event, EventKind, }; use super::{memo::Memo, Configuration, IngredientImpl}; diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 2474acbe..f204145f 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -3,8 +3,7 @@ use arc_swap::Guard; use crate::{ local_state::{self, LocalState}, runtime::StampedValue, - storage::ZalsaDatabase as _, - AsDynDatabase as _, Id, + AsDynDatabase as _, Database as _, Id, }; use super::{Configuration, IngredientImpl}; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index a9426cf7..15a677d5 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -4,8 +4,8 @@ use crate::{ key::DatabaseKeyIndex, local_state::{self, ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, runtime::StampedValue, - storage::{Zalsa, ZalsaDatabase as _}, - AsDynDatabase as _, Id, Revision, + storage::Zalsa, + AsDynDatabase as _, Database, Id, Revision, }; use super::{memo::Memo, Configuration, IngredientImpl}; diff --git a/src/function/specify.rs b/src/function/specify.rs index 37e88082..d8d5dea8 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -2,7 +2,6 @@ use crossbeam::atomic::AtomicCell; use crate::{ local_state::{self, QueryOrigin, QueryRevisions}, - storage::ZalsaDatabase, tracked_struct::TrackedStructInDb, AsDynDatabase as _, Database, DatabaseKeyIndex, Id, }; diff --git a/src/lib.rs b/src/lib.rs index 395938dd..7b38a291 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,6 +31,8 @@ pub use self::cancelled::Cancelled; pub use self::cycle::Cycle; pub use self::database::AsDynDatabase; pub use self::database::Database; +pub use self::database::DatabaseImpl; +pub use self::database::UserData; pub use self::durability::Durability; pub use self::event::Event; pub use self::event::EventKind; @@ -50,23 +52,9 @@ pub use salsa_macros::interned; pub use salsa_macros::tracked; pub use salsa_macros::Update; -pub fn default_database() -> impl Database { - use crate as salsa; - - #[crate::db] - #[derive(Default)] - struct DefaultDatabase { - storage: Storage, - } - - #[crate::db] - impl Database for DefaultDatabase {} - - DefaultDatabase::default() -} - pub mod prelude { pub use crate::Accumulator; + pub use crate::Database; pub use crate::Setter; } @@ -82,6 +70,7 @@ pub mod plumbing { pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::current_revision; pub use crate::database::Database; + pub use crate::database::UserData; pub use crate::function::should_backdate_value; pub use crate::id::AsId; pub use crate::id::FromId; @@ -102,7 +91,6 @@ pub mod plumbing { pub use crate::storage::IngredientIndex; pub use crate::storage::Storage; pub use crate::storage::Zalsa; - pub use crate::storage::ZalsaDatabase; pub use crate::tracked_struct::TrackedStructInDb; pub use crate::update::always_update; pub use crate::update::helper::Dispatch as UpdateDispatch; diff --git a/src/storage.rs b/src/storage.rs index 0ecf132e..f05c22f8 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,10 +1,11 @@ -use std::any::{Any, TypeId}; +use std::any::TypeId; use orx_concurrent_vec::ConcurrentVec; use parking_lot::Mutex; use rustc_hash::FxHashMap; use crate::cycle::CycleRecoveryStrategy; +use crate::database::{DatabaseImpl, UserData}; use crate::ingredient::{Ingredient, Jar}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::Runtime; @@ -15,22 +16,6 @@ pub fn views(db: &Db) -> &Views { db.zalsa().views() } -/// Salsa database methods whose implementation is generated by -/// the `#[salsa::database]` procedural macro. -/// -/// # Safety -/// -/// This trait is meant to be implemented by our procedural macro. -/// We need to document any non-obvious conditions that it satisfies. -pub unsafe trait ZalsaDatabase: Any { - /// Plumbing methods. - #[doc(hidden)] - fn zalsa(&self) -> &dyn Zalsa; - - #[doc(hidden)] - fn zalsa_mut(&mut self) -> &mut dyn Zalsa; -} - /// The "plumbing interface" to the Salsa database. /// /// **NOT SEMVER STABLE.** @@ -87,9 +72,9 @@ pub trait Zalsa { fn report_tracked_write(&mut self, durability: Durability); } -impl Zalsa for Storage { +impl Zalsa for Storage { fn views(&self) -> &Views { - &self.upcasts + &self.views_of } fn nonce(&self) -> Nonce { @@ -227,8 +212,10 @@ impl IngredientIndex { /// The "storage" struct stores all the data for the jars. /// It is shared between the main database and any active snapshots. -pub struct Storage { - upcasts: ViewsOf, +pub struct Storage { + user_data: U, + + views_of: ViewsOf>, nonce: Nonce, @@ -254,19 +241,30 @@ pub struct Storage { } // ANCHOR: default -impl Default for Storage { +impl Default for Storage { fn default() -> Self { + Self::with(Default::default()) + } +} +// ANCHOR_END: default + +impl Storage { + pub(crate) fn with(user_data: U) -> Self { Self { - upcasts: Default::default(), + views_of: Default::default(), nonce: NONCE.nonce(), jar_map: Default::default(), ingredients_vec: Default::default(), ingredients_requiring_reset: Default::default(), runtime: Runtime::default(), + user_data, } } + + pub(crate) fn user_data(&self) -> &U { + &self.user_data + } } -// ANCHOR_END: default /// Caches a pointer to an ingredient in a database. /// Optimized for the case of a single database. diff --git a/tests/accumulate-chain.rs b/tests/accumulate-chain.rs index bf19bc29..b0d79bd4 100644 --- a/tests/accumulate-chain.rs +++ b/tests/accumulate-chain.rs @@ -4,7 +4,7 @@ mod common; use expect_test::expect; -use salsa::{Accumulator, Database}; +use salsa::{Accumulator, Database, DatabaseImpl}; use test_log::test; #[salsa::accumulator] @@ -40,7 +40,7 @@ fn push_d_logs(db: &dyn Database) { #[test] fn accumulate_chain() { - salsa::default_database().attach(|db| { + DatabaseImpl::new().attach(|db| { let logs = push_logs::accumulated::(db); // Check that we get all the logs. expect![[r#" diff --git a/tests/accumulate-custom-clone.rs b/tests/accumulate-custom-clone.rs index deb27791..81612b31 100644 --- a/tests/accumulate-custom-clone.rs +++ b/tests/accumulate-custom-clone.rs @@ -27,7 +27,7 @@ fn push_logs(db: &dyn salsa::Database, input: MyInput) { #[test] fn accumulate_custom_clone() { - salsa::default_database().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = MyInput::new(db, 2); let logs = push_logs::accumulated::(db, input); expect![[r##" diff --git a/tests/accumulate-custom-debug.rs b/tests/accumulate-custom-debug.rs index 59053803..71a4ba86 100644 --- a/tests/accumulate-custom-debug.rs +++ b/tests/accumulate-custom-debug.rs @@ -27,7 +27,7 @@ fn push_logs(db: &dyn salsa::Database, input: MyInput) { #[test] fn accumulate_custom_debug() { - salsa::default_database().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = MyInput::new(db, 2); let logs = push_logs::accumulated::(db, input); expect![[r##" diff --git a/tests/accumulate-dag.rs b/tests/accumulate-dag.rs index d0c0cfeb..e23050ba 100644 --- a/tests/accumulate-dag.rs +++ b/tests/accumulate-dag.rs @@ -39,7 +39,7 @@ fn push_b_logs(db: &dyn Database, input: MyInput) { #[test] fn accumulate_a_called_twice() { - salsa::default_database().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = MyInput::new(db, 2, 3); let logs = push_logs::accumulated::(db, input); // Check that we don't see logs from `a` appearing twice in the input. diff --git a/tests/accumulate-execution-order.rs b/tests/accumulate-execution-order.rs index c1fa0481..edb82e48 100644 --- a/tests/accumulate-execution-order.rs +++ b/tests/accumulate-execution-order.rs @@ -41,7 +41,7 @@ fn push_d_logs(db: &dyn Database) { #[test] fn accumulate_execution_order() { - salsa::default_database().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let logs = push_logs::accumulated::(db); // Check that we get logs in execution order expect![[r#" diff --git a/tests/accumulate-from-tracked-fn.rs b/tests/accumulate-from-tracked-fn.rs index 039b1a9b..33d7bd3f 100644 --- a/tests/accumulate-from-tracked-fn.rs +++ b/tests/accumulate-from-tracked-fn.rs @@ -2,16 +2,10 @@ //! Then mutate the values so that the tracked function re-executes. //! Check that we accumulate the appropriate, new values. -mod common; -use common::{HasLogger, Logger}; - use expect_test::expect; use salsa::{Accumulator, Setter}; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::input] struct List { value: u32, @@ -23,7 +17,7 @@ struct List { struct Integers(u32); #[salsa::tracked] -fn compute(db: &dyn Db, input: List) { +fn compute(db: &dyn salsa::Database, input: List) { eprintln!( "{:?}(value={:?}, next={:?})", input, @@ -43,30 +37,9 @@ fn compute(db: &dyn Db, input: List) { eprintln!("pushed result {:?}", result); } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} -} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn test1() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let l0 = List::new(&db, 1, None); let l1 = List::new(&db, 10, Some(l0)); diff --git a/tests/accumulate-no-duplicates.rs b/tests/accumulate-no-duplicates.rs index 10d47baa..faf8c03a 100644 --- a/tests/accumulate-no-duplicates.rs +++ b/tests/accumulate-no-duplicates.rs @@ -73,7 +73,7 @@ fn push_e_logs(db: &dyn Database) { #[test] fn accumulate_no_duplicates() { - salsa::default_database().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let logs = push_logs::accumulated::(db); // Test that there aren't duplicate B logs. // Note that log A appears twice, because they both come diff --git a/tests/accumulate-reuse-workaround.rs b/tests/accumulate-reuse-workaround.rs index 783ca0cc..f43c098e 100644 --- a/tests/accumulate-reuse-workaround.rs +++ b/tests/accumulate-reuse-workaround.rs @@ -3,15 +3,12 @@ //! reuse. mod common; -use common::{HasLogger, Logger}; +use common::{LogDatabase, Logger}; use expect_test::expect; -use salsa::{Accumulator, Setter}; +use salsa::{Accumulator, DatabaseImpl, Setter}; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::input] struct List { value: u32, @@ -23,7 +20,7 @@ struct List { struct Integers(u32); #[salsa::tracked] -fn compute(db: &dyn Db, input: List) -> u32 { +fn compute(db: &dyn LogDatabase, input: List) -> u32 { db.push_log(format!("compute({:?})", input,)); // always pushes 0 @@ -42,38 +39,17 @@ fn compute(db: &dyn Db, input: List) -> u32 { } #[salsa::tracked(return_ref)] -fn accumulated(db: &dyn Db, input: List) -> Vec { - db.push_log(format!("accumulated({:?})", input,)); +fn accumulated(db: &dyn LogDatabase, input: List) -> Vec { + db.push_log(format!("accumulated({:?})", input)); compute::accumulated::(db, input) .into_iter() .map(|a| a.0) .collect() } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} -} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn test1() { - let mut db = Database::default(); + let mut db: DatabaseImpl = DatabaseImpl::default(); let l1 = List::new(&db, 1, None); let l2 = List::new(&db, 2, Some(l1)); diff --git a/tests/accumulate-reuse.rs b/tests/accumulate-reuse.rs index 2075a8e7..e9ac47b3 100644 --- a/tests/accumulate-reuse.rs +++ b/tests/accumulate-reuse.rs @@ -4,15 +4,12 @@ //! are the accumulated values from another query. mod common; -use common::{HasLogger, Logger}; +use common::{LogDatabase, Logger}; use expect_test::expect; -use salsa::prelude::*; +use salsa::{prelude::*, DatabaseImpl}; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::input] struct List { value: u32, @@ -23,7 +20,7 @@ struct List { struct Integers(u32); #[salsa::tracked] -fn compute(db: &dyn Db, input: List) -> u32 { +fn compute(db: &dyn LogDatabase, input: List) -> u32 { db.push_log(format!("compute({:?})", input,)); // always pushes 0 @@ -41,30 +38,9 @@ fn compute(db: &dyn Db, input: List) -> u32 { result } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} -} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn test1() { - let mut db = Database::default(); + let mut db = DatabaseImpl::with(Logger::default()); let l1 = List::new(&db, 1, None); let l2 = List::new(&db, 2, Some(l1)); diff --git a/tests/accumulate.rs b/tests/accumulate.rs index 099128bd..e20cc05f 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -1,13 +1,10 @@ mod common; -use common::{HasLogger, Logger}; +use common::{LogDatabase, Logger}; use expect_test::expect; use salsa::{Accumulator, Setter}; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::input] struct MyInput { field_a: u32, @@ -18,7 +15,7 @@ struct MyInput { struct Log(#[allow(dead_code)] String); #[salsa::tracked] -fn push_logs(db: &dyn Db, input: MyInput) { +fn push_logs(db: &dyn LogDatabase, input: MyInput) { db.push_log(format!( "push_logs(a = {}, b = {})", input.field_a(db), @@ -37,7 +34,7 @@ fn push_logs(db: &dyn Db, input: MyInput) { } #[salsa::tracked] -fn push_a_logs(db: &dyn Db, input: MyInput) { +fn push_a_logs(db: &dyn LogDatabase, input: MyInput) { let field_a = input.field_a(db); db.push_log(format!("push_a_logs({})", field_a)); @@ -47,7 +44,7 @@ fn push_a_logs(db: &dyn Db, input: MyInput) { } #[salsa::tracked] -fn push_b_logs(db: &dyn Db, input: MyInput) { +fn push_b_logs(db: &dyn LogDatabase, input: MyInput) { let field_a = input.field_b(db); db.push_log(format!("push_b_logs({})", field_a)); @@ -56,30 +53,9 @@ fn push_b_logs(db: &dyn Db, input: MyInput) { } } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} -} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn accumulate_once() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::with(Logger::default()); // Just call accumulate on a base input to see what happens. let input = MyInput::new(&db, 2, 3); @@ -115,7 +91,7 @@ fn accumulate_once() { #[test] fn change_a_from_2_to_0() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::with(Logger::default()); // Accumulate logs for `a = 2` and `b = 3` let input = MyInput::new(&db, 2, 3); @@ -170,7 +146,7 @@ fn change_a_from_2_to_0() { #[test] fn change_a_from_2_to_1() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::with(Logger::default()); // Accumulate logs for `a = 2` and `b = 3` let input = MyInput::new(&db, 2, 3); @@ -229,7 +205,7 @@ fn change_a_from_2_to_1() { #[test] fn get_a_logs_after_changing_b() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::with(Logger::default()); // Invoke `push_a_logs` with `a = 2` and `b = 3` (but `b` doesn't matter) let input = MyInput::new(&db, 2, 3); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 8e5d51ef..cb741f37 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -2,16 +2,21 @@ #![allow(dead_code)] +use salsa::{DatabaseImpl, UserData}; + +/// Logging userdata: provides [`LogDatabase`][] trait. +/// +/// If you wish to use it along with other userdata, +/// you can also embed it in another struct and implement [`HasLogger`][] for that struct. #[derive(Default)] pub struct Logger { logs: std::sync::Mutex>, } -/// Trait implemented by databases that lets them log events. -pub trait HasLogger { - /// Return a reference to the logger from the database. - fn logger(&self) -> &Logger; +impl UserData for Logger {} +#[salsa::db] +pub trait LogDatabase: HasLogger + salsa::Database { /// Log an event from inside a tracked function. fn push_log(&self, string: String) { self.logger().logs.lock().unwrap().push(string); @@ -33,3 +38,86 @@ pub trait HasLogger { assert_eq!(logs.len(), expected); } } + +#[salsa::db] +impl LogDatabase for DatabaseImpl {} + +/// Trait implemented by databases that lets them log events. +pub trait HasLogger { + /// Return a reference to the logger from the database. + fn logger(&self) -> &Logger; +} + +impl HasLogger for DatabaseImpl { + fn logger(&self) -> &Logger { + U::logger(self) + } +} + +impl HasLogger for Logger { + fn logger(&self) -> &Logger { + self + } +} + +/// Userdata that provides logging and logs salsa events. +#[derive(Default)] +pub struct EventLogger { + logger: Logger, +} + +impl UserData for EventLogger { + fn salsa_event(db: &DatabaseImpl, event: &dyn Fn() -> salsa::Event) { + db.push_log(format!("{:?}", event())); + } +} + +impl HasLogger for EventLogger { + fn logger(&self) -> &Logger { + &self.logger + } +} + +#[derive(Default)] +pub struct DiscardLogger(Logger); + +impl UserData for DiscardLogger { + fn salsa_event(db: &DatabaseImpl, event: &dyn Fn() -> salsa::Event) { + let event = event(); + match event.kind { + salsa::EventKind::WillDiscardStaleOutput { .. } + | salsa::EventKind::DidDiscard { .. } => { + db.push_log(format!("salsa_event({:?})", event.kind)); + } + _ => {} + } + } +} + +impl HasLogger for DiscardLogger { + fn logger(&self) -> &Logger { + &self.0 + } +} + +#[derive(Default)] +pub struct ExecuteValidateLogger(Logger); + +impl UserData for ExecuteValidateLogger { + fn salsa_event(db: &DatabaseImpl, event: &dyn Fn() -> salsa::Event) { + let event = event(); + match event.kind { + salsa::EventKind::WillExecute { .. } + | salsa::EventKind::DidValidateMemoizedValue { .. } => { + db.push_log(format!("salsa_event({:?})", event.kind)); + } + _ => {} + } + } +} + +impl HasLogger for ExecuteValidateLogger { + fn logger(&self) -> &Logger { + &self.0 + } +} diff --git a/tests/compile-fail/get-set-on-private-input-field.rs b/tests/compile-fail/get-set-on-private-input-field.rs index ae1dd75e..5ecec583 100644 --- a/tests/compile-fail/get-set-on-private-input-field.rs +++ b/tests/compile-fail/get-set-on-private-input-field.rs @@ -8,7 +8,7 @@ mod a { } fn main() { - let mut db = salsa::default_database(); + let mut db = salsa::DatabaseImpl::new(); let input = a::MyInput::new(&mut db, 22); input.field(&db); diff --git a/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs b/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs index 0860efc3..3dbb4f2f 100644 --- a/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs +++ b/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs @@ -16,7 +16,7 @@ fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTracked<'d } fn main() { - let mut db = salsa::default_database(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); let tracked = tracked_fn(&db, input); input.set_field(&mut db).to(24); diff --git a/tests/compile-fail/span-input-setter.rs b/tests/compile-fail/span-input-setter.rs index 1f4a4513..9abf4b6c 100644 --- a/tests/compile-fail/span-input-setter.rs +++ b/tests/compile-fail/span-input-setter.rs @@ -4,7 +4,7 @@ pub struct MyInput { } fn main() { - let mut db = salsa::default_database(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&mut db, 22); input.field(&db); input.set_field(22); diff --git a/tests/compile-fail/span-tracked-getter.rs b/tests/compile-fail/span-tracked-getter.rs index d588c140..245fd100 100644 --- a/tests/compile-fail/span-tracked-getter.rs +++ b/tests/compile-fail/span-tracked-getter.rs @@ -10,6 +10,6 @@ fn my_fn(db: &dyn salsa::Database) { } fn main() { - let mut db = salsa::default_database(); + let mut db = salsa::DatabaseImpl::new(); my_fn(&db); } diff --git a/tests/compile-fail/span-tracked-getter.stderr b/tests/compile-fail/span-tracked-getter.stderr index 8ae7219e..fcf546c7 100644 --- a/tests/compile-fail/span-tracked-getter.stderr +++ b/tests/compile-fail/span-tracked-getter.stderr @@ -24,7 +24,7 @@ help: consider borrowing here warning: variable does not need to be mutable --> tests/compile-fail/span-tracked-getter.rs:13:9 | -13 | let mut db = salsa::default_database(); +13 | let mut db = salsa::DatabaseImpl::new(); | ----^^ | | | help: remove this `mut` diff --git a/tests/cycles.rs b/tests/cycles.rs index a37d6f6c..1cc2f06a 100644 --- a/tests/cycles.rs +++ b/tests/cycles.rs @@ -3,6 +3,7 @@ use std::panic::{RefUnwindSafe, UnwindSafe}; use expect_test::expect; +use salsa::DatabaseImpl; use salsa::Durability; // Axes: @@ -59,17 +60,6 @@ struct Error { use salsa::Database as Db; use salsa::Setter; -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - -impl RefUnwindSafe for Database {} - #[salsa::input] struct MyInput {} @@ -169,7 +159,7 @@ fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { #[test] fn cycle_memoized() { - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = MyInput::new(db); let cycle = extract_cycle(|| memoized_a(db, input)); let expected = expect![[r#" @@ -184,7 +174,7 @@ fn cycle_memoized() { #[test] fn cycle_volatile() { - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = MyInput::new(db); let cycle = extract_cycle(|| volatile_a(db, input)); let expected = expect![[r#" @@ -203,7 +193,7 @@ fn expect_cycle() { // ^ | // +-----+ - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::None); assert!(cycle_a(db, abc).is_err()); }) @@ -214,7 +204,7 @@ fn inner_cycle() { // A --> B <-- C // ^ | // +-----+ - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::B); let err = cycle_c(db, abc); assert!(err.is_err()); @@ -233,7 +223,7 @@ fn cycle_revalidate() { // A --> B // ^ | // +-----+ - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); assert!(cycle_a(&db, abc).is_err()); abc.set_b(&mut db).to(CycleQuery::A); // same value as default @@ -245,7 +235,7 @@ fn cycle_recovery_unchanged_twice() { // A --> B // ^ | // +-----+ - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); assert!(cycle_a(&db, abc).is_err()); @@ -255,8 +245,7 @@ fn cycle_recovery_unchanged_twice() { #[test] fn cycle_appears() { - let mut db = Database::default(); - + let mut db = salsa::DatabaseImpl::new(); // A --> B let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None); assert!(cycle_a(&db, abc).is_ok()); @@ -270,7 +259,7 @@ fn cycle_appears() { #[test] fn cycle_disappears() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); // A --> B // ^ | @@ -289,7 +278,7 @@ fn cycle_disappears() { /// the fact that the cycle will no longer occur. #[test] fn cycle_disappears_durability() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let abc = ABC::new( &mut db, CycleQuery::None, @@ -320,7 +309,7 @@ fn cycle_disappears_durability() { #[test] fn cycle_mixed_1() { - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { // A --> B <-- C // | ^ // +-----+ @@ -338,7 +327,7 @@ fn cycle_mixed_1() { #[test] fn cycle_mixed_2() { - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { // Configuration: // // A --> B --> C @@ -360,7 +349,7 @@ fn cycle_mixed_2() { fn cycle_deterministic_order() { // No matter whether we start from A or B, we get the same set of participants: let f = || { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); // A --> B // ^ | @@ -390,7 +379,7 @@ fn cycle_deterministic_order() { #[test] fn cycle_multiple() { // No matter whether we start from A or B, we get the same set of participants: - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); // Configuration: // @@ -432,7 +421,7 @@ fn cycle_multiple() { #[test] fn cycle_recovery_set_but_not_participating() { - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { // A --> C -+ // ^ | // +--+ diff --git a/tests/debug.rs b/tests/debug.rs index f1bde114..194c6511 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -1,7 +1,7 @@ //! Test that `DeriveWithDb` is correctly derived. use expect_test::expect; -use salsa::{Database as _, Setter}; +use salsa::{Database, Setter}; #[salsa::input] struct MyInput { @@ -19,18 +19,9 @@ struct ComplexStruct { not_salsa: NotSalsa, } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn input() { - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = MyInput::new(db, 22); let not_salsa = NotSalsa { field: "it's salsa time".to_string(), @@ -54,7 +45,7 @@ fn leak_debug_string(_db: &dyn salsa::Database, input: MyInput) -> String { /// Don't try this at home, kids. #[test] fn untracked_dependencies() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); @@ -95,7 +86,7 @@ fn leak_derived_custom(db: &dyn salsa::Database, input: MyInput, value: u32) -> #[test] fn custom_debug_impl() { - let db = Database::default(); + let db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); diff --git a/tests/deletion-cascade.rs b/tests/deletion-cascade.rs index 8385d048..023e584b 100644 --- a/tests/deletion-cascade.rs +++ b/tests/deletion-cascade.rs @@ -3,22 +3,19 @@ //! * when we delete memoized data, also delete outputs from that data mod common; -use common::{HasLogger, Logger}; +use common::{DiscardLogger, LogDatabase}; use expect_test::expect; -use salsa::Setter; +use salsa::{DatabaseImpl, Setter}; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::input(singleton)] struct MyInput { field: u32, } #[salsa::tracked] -fn final_result(db: &dyn Db, input: MyInput) -> u32 { +fn final_result(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("final_result({:?})", input)); let mut sum = 0; for tracked_struct in create_tracked_structs(db, input) { @@ -33,7 +30,7 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn create_tracked_structs(db: &dyn Db, input: MyInput) -> Vec> { +fn create_tracked_structs(db: &dyn LogDatabase, input: MyInput) -> Vec> { db.push_log(format!("intermediate_result({:?})", input)); (0..input.field(db)) .map(|i| MyTracked::new(db, i)) @@ -41,49 +38,19 @@ fn create_tracked_structs(db: &dyn Db, input: MyInput) -> Vec> { } #[salsa::tracked] -fn contribution_from_struct<'db>(db: &'db dyn Db, tracked: MyTracked<'db>) -> u32 { +fn contribution_from_struct<'db>(db: &'db dyn LogDatabase, tracked: MyTracked<'db>) -> u32 { let m = MyTracked::new(db, tracked.field(db)); copy_field(db, m) * 2 } #[salsa::tracked] -fn copy_field<'db>(db: &'db dyn Db, tracked: MyTracked<'db>) -> u32 { +fn copy_field<'db>(db: &'db dyn LogDatabase, tracked: MyTracked<'db>) -> u32 { tracked.field(db) } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - let event = event(); - match event.kind { - salsa::EventKind::WillDiscardStaleOutput { .. } - | salsa::EventKind::DidDiscard { .. } => { - self.push_log(format!("salsa_event({:?})", event.kind)); - } - _ => {} - } - } -} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn basic() { - let mut db = Database::default(); + let mut db: DatabaseImpl = Default::default(); // Creates 3 tracked structs let input = MyInput::new(&db, 3); diff --git a/tests/deletion-drops.rs b/tests/deletion-drops.rs index 989f934f..b03ceda7 100644 --- a/tests/deletion-drops.rs +++ b/tests/deletion-drops.rs @@ -56,7 +56,7 @@ impl MyInput { #[test] fn deletion_drops() { - let mut db = salsa::default_database(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); diff --git a/tests/deletion.rs b/tests/deletion.rs index cdcc14bf..74b69547 100644 --- a/tests/deletion.rs +++ b/tests/deletion.rs @@ -3,22 +3,19 @@ //! * entities not created in a revision are deleted, as is any memoized data keyed on them. mod common; -use common::{HasLogger, Logger}; +use common::LogDatabase; use expect_test::expect; use salsa::Setter; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::input] struct MyInput { field: u32, } #[salsa::tracked] -fn final_result(db: &dyn Db, input: MyInput) -> u32 { +fn final_result(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("final_result({:?})", input)); let mut sum = 0; for tracked_struct in create_tracked_structs(db, input) { @@ -33,7 +30,7 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn create_tracked_structs(db: &dyn Db, input: MyInput) -> Vec> { +fn create_tracked_structs(db: &dyn LogDatabase, input: MyInput) -> Vec> { db.push_log(format!("intermediate_result({:?})", input)); (0..input.field(db)) .map(|i| MyTracked::new(db, i)) @@ -41,43 +38,13 @@ fn create_tracked_structs(db: &dyn Db, input: MyInput) -> Vec> { } #[salsa::tracked] -fn contribution_from_struct<'db>(db: &'db dyn Db, tracked: MyTracked<'db>) -> u32 { +fn contribution_from_struct<'db>(db: &'db dyn LogDatabase, tracked: MyTracked<'db>) -> u32 { tracked.field(db) * 2 } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - let event = event(); - match event.kind { - salsa::EventKind::WillDiscardStaleOutput { .. } - | salsa::EventKind::DidDiscard { .. } => { - self.push_log(format!("salsa_event({:?})", event.kind)); - } - _ => {} - } - } -} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn basic() { - let mut db = Database::default(); + let mut db: salsa::DatabaseImpl = Default::default(); // Creates 3 tracked structs let input = MyInput::new(&db, 3); diff --git a/tests/elided-lifetime-in-tracked-fn.rs b/tests/elided-lifetime-in-tracked-fn.rs index bd0e2184..07090d60 100644 --- a/tests/elided-lifetime-in-tracked-fn.rs +++ b/tests/elided-lifetime-in-tracked-fn.rs @@ -2,22 +2,19 @@ //! compiles and executes successfully. mod common; -use common::{HasLogger, Logger}; +use common::{LogDatabase, Logger}; use expect_test::expect; -use salsa::Setter; +use salsa::{DatabaseImpl, Setter}; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::input] struct MyInput { field: u32, } #[salsa::tracked] -fn final_result(db: &dyn Db, input: MyInput) -> u32 { +fn final_result(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("final_result({:?})", input)); intermediate_result(db, input).field(db) * 2 } @@ -28,33 +25,14 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn intermediate_result(db: &dyn Db, input: MyInput) -> MyTracked<'_> { +fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { db.push_log(format!("intermediate_result({:?})", input)); MyTracked::new(db, input.field(db) / 2) } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database {} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn execute() { - let mut db = Database::default(); + let mut db: DatabaseImpl = Default::default(); let input = MyInput::new(&db, 22); assert_eq!(final_result(&db, input), 22); diff --git a/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs b/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs index d9fb6d52..5c110c21 100644 --- a/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs +++ b/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs @@ -4,13 +4,10 @@ #![allow(dead_code)] mod common; -use common::{HasLogger, Logger}; +use common::{LogDatabase, Logger}; use expect_test::expect; -use salsa::Setter; - -#[salsa::db] -trait Db: salsa::Database + HasLogger {} +use salsa::{DatabaseImpl, Setter}; #[salsa::input] struct MyInput { @@ -18,13 +15,13 @@ struct MyInput { } #[salsa::tracked] -fn final_result_depends_on_x(db: &dyn Db, input: MyInput) -> u32 { +fn final_result_depends_on_x(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("final_result_depends_on_x({:?})", input)); intermediate_result(db, input).x(db) * 2 } #[salsa::tracked] -fn final_result_depends_on_y(db: &dyn Db, input: MyInput) -> u32 { +fn final_result_depends_on_y(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("final_result_depends_on_y({:?})", input)); intermediate_result(db, input).y(db) * 2 } @@ -36,36 +33,17 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn intermediate_result(db: &dyn Db, input: MyInput) -> MyTracked<'_> { +fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { MyTracked::new(db, (input.field(db) + 1) / 2, input.field(db) / 2) } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database {} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn execute() { // x = (input.field + 1) / 2 // y = input.field / 2 // final_result_depends_on_x = x * 2 = (input.field + 1) / 2 * 2 // final_result_depends_on_y = y * 2 = input.field / 2 * 2 - let mut db = Database::default(); + let mut db: DatabaseImpl = Default::default(); // intermediate results: // x = (22 + 1) / 2 = 11 diff --git a/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs b/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs index c1c2df18..12f2cfbf 100644 --- a/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs +++ b/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs @@ -4,14 +4,11 @@ #![allow(dead_code)] mod common; -use common::{HasLogger, Logger}; +use common::{LogDatabase, Logger}; use expect_test::expect; use salsa::Setter; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::input] struct MyInput { x: u32, @@ -19,41 +16,21 @@ struct MyInput { } #[salsa::tracked] -fn result_depends_on_x(db: &dyn Db, input: MyInput) -> u32 { +fn result_depends_on_x(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("result_depends_on_x({:?})", input)); input.x(db) + 1 } #[salsa::tracked] -fn result_depends_on_y(db: &dyn Db, input: MyInput) -> u32 { +fn result_depends_on_y(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("result_depends_on_y({:?})", input)); input.y(db) - 1 } - -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database {} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn execute() { // result_depends_on_x = x + 1 // result_depends_on_y = y - 1 - let mut db = Database::default(); + let mut db: salsa::DatabaseImpl = Default::default(); let input = MyInput::new(&db, 22, 33); assert_eq!(result_depends_on_x(&db, input), 23); diff --git a/tests/hello_world.rs b/tests/hello_world.rs index 63e3d5ce..3a316d1e 100644 --- a/tests/hello_world.rs +++ b/tests/hello_world.rs @@ -2,22 +2,19 @@ //! compiles and executes successfully. mod common; -use common::{HasLogger, Logger}; +use common::{LogDatabase, Logger}; use expect_test::expect; use salsa::Setter; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::input] struct MyInput { field: u32, } #[salsa::tracked] -fn final_result(db: &dyn Db, input: MyInput) -> u32 { +fn final_result(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("final_result({:?})", input)); intermediate_result(db, input).field(db) * 2 } @@ -28,33 +25,14 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn intermediate_result(db: &dyn Db, input: MyInput) -> MyTracked<'_> { +fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { db.push_log(format!("intermediate_result({:?})", input)); MyTracked::new(db, input.field(db) / 2) } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database {} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn execute() { - let mut db = Database::default(); + let mut db: salsa::DatabaseImpl = Default::default(); let input = MyInput::new(&db, 22); assert_eq!(final_result(&db, input), 22); @@ -85,7 +63,7 @@ fn execute() { /// Create and mutate a distinct input. No re-execution required. #[test] fn red_herring() { - let mut db = Database::default(); + let mut db: salsa::DatabaseImpl = Default::default(); let input = MyInput::new(&db, 22); assert_eq!(final_result(&db, input), 22); diff --git a/tests/interned-struct-with-lifetime.rs b/tests/interned-struct-with-lifetime.rs index aa06678a..a74d2c42 100644 --- a/tests/interned-struct-with-lifetime.rs +++ b/tests/interned-struct-with-lifetime.rs @@ -1,14 +1,9 @@ //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. -mod common; -use common::{HasLogger, Logger}; use expect_test::expect; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::interned] struct InternedString<'db> { data: String, @@ -20,38 +15,17 @@ struct InternedPair<'db> { } #[salsa::tracked] -fn intern_stuff(db: &dyn Db) -> String { +fn intern_stuff(db: &dyn salsa::Database) -> String { let s1 = InternedString::new(db, "Hello, ".to_string()); let s2 = InternedString::new(db, "World, ".to_string()); let s3 = InternedPair::new(db, (s1, s2)); format!("{s3:?}") } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database {} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn execute() { - let mut db = Database::default(); - + let db = salsa::DatabaseImpl::new(); expect![[r#" "InternedPair { data: (InternedString { data: \"Hello, \" }, InternedString { data: \"World, \" }) }" "#]].assert_debug_eq(&intern_stuff(&db)); - db.assert_logs(expect!["[]"]); } diff --git a/tests/is_send_sync.rs b/tests/is_send_sync.rs index 9fa994bf..6ada1bac 100644 --- a/tests/is_send_sync.rs +++ b/tests/is_send_sync.rs @@ -1,23 +1,9 @@ //! Test that a setting a field on a `#[salsa::input]` //! overwrites and returns the old value. +use salsa::Database; use test_log::test; -#[salsa::db] -trait Db: salsa::Database {} - -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - -#[salsa::db] -impl Db for Database {} - #[salsa::input] struct MyInput { field: String, @@ -34,7 +20,7 @@ struct MyInterned<'db> { } #[salsa::tracked] -fn test(db: &dyn crate::Db, input: MyInput) { +fn test(db: &dyn Database, input: MyInput) { let input = is_send_sync(input); let interned = is_send_sync(MyInterned::new(db, input.field(db).clone())); let _tracked_struct = is_send_sync(MyTracked::new(db, interned)); @@ -46,7 +32,7 @@ fn is_send_sync(t: T) -> T { #[test] fn execute() { - let db = Database::default(); + let db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, "Hello".to_string()); test(&db, input); } diff --git a/tests/lru.rs b/tests/lru.rs index f45539b0..2d5b1b5d 100644 --- a/tests/lru.rs +++ b/tests/lru.rs @@ -6,14 +6,11 @@ use std::sync::{ Arc, }; -use salsa::Database as _; mod common; -use common::{HasLogger, Logger}; +use common::{LogDatabase, Logger}; +use salsa::Database as _; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[derive(Debug, PartialEq, Eq)] struct HotPotato(u32); @@ -40,50 +37,31 @@ struct MyInput { } #[salsa::tracked(lru = 32)] -fn get_hot_potato(db: &dyn Db, input: MyInput) -> Arc { +fn get_hot_potato(db: &dyn LogDatabase, input: MyInput) -> Arc { db.push_log(format!("get_hot_potato({:?})", input.field(db))); Arc::new(HotPotato::new(input.field(db))) } #[salsa::tracked] -fn get_hot_potato2(db: &dyn Db, input: MyInput) -> u32 { +fn get_hot_potato2(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("get_hot_potato2({:?})", input.field(db))); get_hot_potato(db, input).0 } #[salsa::tracked(lru = 32)] -fn get_volatile(db: &dyn Db, _input: MyInput) -> usize { +fn get_volatile(db: &dyn LogDatabase, _input: MyInput) -> usize { static COUNTER: AtomicUsize = AtomicUsize::new(0); db.report_untracked_read(); COUNTER.fetch_add(1, Ordering::SeqCst) } -#[salsa::db] -#[derive(Default)] -struct DatabaseImpl { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for DatabaseImpl {} - -#[salsa::db] -impl Db for DatabaseImpl {} - -impl HasLogger for DatabaseImpl { - fn logger(&self) -> &Logger { - &self.logger - } -} - fn load_n_potatoes() -> usize { N_POTATOES.with(|n| n.load(Ordering::SeqCst)) } #[test] fn lru_works() { - let db = DatabaseImpl::default(); + let db: salsa::DatabaseImpl = Default::default(); assert_eq!(load_n_potatoes(), 0); for i in 0..128u32 { @@ -99,7 +77,7 @@ fn lru_works() { #[test] fn lru_doesnt_break_volatile_queries() { - let db = DatabaseImpl::default(); + let db: salsa::DatabaseImpl = Default::default(); // Create all inputs first, so that there are no revision changes among calls to `get_volatile` let inputs: Vec = (0..128usize).map(|i| MyInput::new(&db, i as u32)).collect(); @@ -117,7 +95,7 @@ fn lru_doesnt_break_volatile_queries() { #[test] fn lru_can_be_changed_at_runtime() { - let db = DatabaseImpl::default(); + let db: salsa::DatabaseImpl = Default::default(); assert_eq!(load_n_potatoes(), 0); let inputs: Vec<(u32, MyInput)> = (0..128).map(|i| (i, MyInput::new(&db, i))).collect(); @@ -160,7 +138,7 @@ fn lru_can_be_changed_at_runtime() { #[test] fn lru_keeps_dependency_info() { - let mut db = DatabaseImpl::default(); + let mut db: salsa::DatabaseImpl = Default::default(); let capacity = 32; // Invoke `get_hot_potato2` 33 times. This will (in turn) invoke diff --git a/tests/mutate_in_place.rs b/tests/mutate_in_place.rs index d8fa88b0..047373ee 100644 --- a/tests/mutate_in_place.rs +++ b/tests/mutate_in_place.rs @@ -1,9 +1,6 @@ //! Test that a setting a field on a `#[salsa::input]` //! overwrites and returns the old value. -mod common; -use common::{HasLogger, Logger}; - use salsa::Setter; use test_log::test; @@ -12,25 +9,9 @@ struct MyInput { field: String, } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn execute() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, "Hello".to_string()); diff --git a/tests/override_new_get_set.rs b/tests/override_new_get_set.rs index a5604962..9f3a8752 100644 --- a/tests/override_new_get_set.rs +++ b/tests/override_new_get_set.rs @@ -66,17 +66,5 @@ impl<'db> MyTracked<'db> { #[test] fn execute() { - #[salsa::db] - #[derive(Default)] - struct Database { - storage: salsa::Storage, - } - - #[salsa::db] - impl salsa::Database for Database {} - - #[salsa::db] - impl Db for Database {} - - let mut db = Database::default(); + salsa::DatabaseImpl::new(); } diff --git a/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs b/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs index d59741aa..32b444c7 100644 --- a/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs +++ b/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs @@ -6,20 +6,11 @@ struct MyTracked<'db> { field: u32, } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] #[should_panic( expected = "cannot create a tracked struct disambiguator outside of a tracked function" )] fn execute() { - let db = Database::default(); + let db = salsa::DatabaseImpl::new(); MyTracked::new(&db, 0); } diff --git a/tests/parallel/parallel_cancellation.rs b/tests/parallel/parallel_cancellation.rs index 715f8a72..0e35ab25 100644 --- a/tests/parallel/parallel_cancellation.rs +++ b/tests/parallel/parallel_cancellation.rs @@ -3,17 +3,12 @@ //! both intra and cross thread. use salsa::Cancelled; +use salsa::DatabaseImpl; use salsa::Handle; use salsa::Setter; -use crate::setup::Database; use crate::setup::Knobs; - -#[salsa::db] -pub(crate) trait Db: salsa::Database + Knobs {} - -#[salsa::db] -impl Db for T {} +use crate::setup::KnobsDatabase; #[salsa::input] struct MyInput { @@ -21,14 +16,14 @@ struct MyInput { } #[salsa::tracked] -fn a1(db: &dyn Db, input: MyInput) -> MyInput { +fn a1(db: &dyn KnobsDatabase, input: MyInput) -> MyInput { db.signal(1); db.wait_for(2); dummy(db, input) } #[salsa::tracked] -fn dummy(_db: &dyn Db, _input: MyInput) -> MyInput { +fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> MyInput { panic!("should never get here!") } @@ -49,7 +44,7 @@ fn dummy(_db: &dyn Db, _input: MyInput) -> MyInput { #[test] fn execute() { - let mut db = Handle::new(Database::default()); + let mut db = Handle::new(>::default()); db.knobs().signal_on_will_block.store(3); let input = MyInput::new(&*db, 1); diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs index 32dbbf6f..7706d6ec 100644 --- a/tests/parallel/parallel_cycle_all_recover.rs +++ b/tests/parallel/parallel_cycle_all_recover.rs @@ -2,16 +2,11 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. +use salsa::DatabaseImpl; use salsa::Handle; -use crate::setup::Database; use crate::setup::Knobs; - -#[salsa::db] -pub(crate) trait Db: salsa::Database + Knobs {} - -#[salsa::db] -impl Db for T {} +use crate::setup::KnobsDatabase; #[salsa::input] pub(crate) struct MyInput { @@ -19,7 +14,7 @@ pub(crate) struct MyInput { } #[salsa::tracked(recovery_fn = recover_a1)] -pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // Wait to create the cycle until both threads have entered db.signal(1); db.wait_for(2); @@ -27,23 +22,23 @@ pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { a2(db, input) } -fn recover_a1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +fn recover_a1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { dbg!("recover_a1"); key.field(db) * 10 + 1 } #[salsa::tracked(recovery_fn=recover_a2)] -pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { b1(db, input) } -fn recover_a2(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +fn recover_a2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { dbg!("recover_a2"); key.field(db) * 10 + 2 } #[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // Wait to create the cycle until both threads have entered db.wait_for(1); db.signal(2); @@ -53,17 +48,17 @@ pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { b2(db, input) } -fn recover_b1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { dbg!("recover_b1"); key.field(db) * 20 + 1 } #[salsa::tracked(recovery_fn=recover_b2)] -pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { a1(db, input) } -fn recover_b2(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { dbg!("recover_b2"); key.field(db) * 20 + 2 } @@ -92,7 +87,7 @@ fn recover_b2(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { #[test] fn execute() { - let db = Handle::new(Database::default()); + let db = Handle::new(>::default()); db.knobs().signal_on_will_block.store(3); let input = MyInput::new(&*db, 1); diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs index 867a6be7..0c5e3475 100644 --- a/tests/parallel/parallel_cycle_mid_recover.rs +++ b/tests/parallel/parallel_cycle_mid_recover.rs @@ -2,16 +2,9 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. -use salsa::Handle; +use salsa::{DatabaseImpl, Handle}; -use crate::setup::Database; -use crate::setup::Knobs; - -#[salsa::db] -pub(crate) trait Db: salsa::Database + Knobs {} - -#[salsa::db] -impl Db for T {} +use crate::setup::{Knobs, KnobsDatabase}; #[salsa::input] pub(crate) struct MyInput { @@ -19,7 +12,7 @@ pub(crate) struct MyInput { } #[salsa::tracked] -pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // tell thread b we have started db.signal(1); @@ -30,25 +23,25 @@ pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { } #[salsa::tracked] -pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // create the cycle b1(db, input) } #[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // wait for thread a to have started db.wait_for(1); b2(db, input) } -fn recover_b1(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { dbg!("recover_b1"); key.field(db) * 20 + 2 } #[salsa::tracked] -pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // will encounter a cycle but recover b3(db, input); b1(db, input); // hasn't recovered yet @@ -56,12 +49,12 @@ pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { } #[salsa::tracked(recovery_fn=recover_b3)] -pub(crate) fn b3(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn b3(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // will block on thread a, signaling stage 2 a1(db, input) } -fn recover_b3(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { dbg!("recover_b3"); key.field(db) * 200 + 2 } @@ -88,7 +81,7 @@ fn recover_b3(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { #[test] fn execute() { - let db = Handle::new(Database::default()); + let db = Handle::new(>::default()); db.knobs().signal_on_will_block.store(3); let input = MyInput::new(&*db, 1); diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs index 59036685..39b6299c 100644 --- a/tests/parallel/parallel_cycle_none_recover.rs +++ b/tests/parallel/parallel_cycle_none_recover.rs @@ -2,25 +2,20 @@ //! See the `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. -use crate::setup::Database; use crate::setup::Knobs; +use crate::setup::KnobsDatabase; use expect_test::expect; use salsa::Database as _; +use salsa::DatabaseImpl; use salsa::Handle; -#[salsa::db] -pub(crate) trait Db: salsa::Database + Knobs {} - -#[salsa::db] -impl Db for T {} - #[salsa::input] pub(crate) struct MyInput { field: i32, } #[salsa::tracked] -pub(crate) fn a(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn a(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // Wait to create the cycle until both threads have entered db.signal(1); db.wait_for(2); @@ -29,7 +24,7 @@ pub(crate) fn a(db: &dyn Db, input: MyInput) -> i32 { } #[salsa::tracked] -pub(crate) fn b(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn b(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // Wait to create the cycle until both threads have entered db.wait_for(1); db.signal(2); @@ -43,7 +38,7 @@ pub(crate) fn b(db: &dyn Db, input: MyInput) -> i32 { #[test] fn execute() { - let db = Handle::new(Database::default()); + let db = Handle::new(>::default()); db.knobs().signal_on_will_block.store(3); let input = MyInput::new(&*db, -1); diff --git a/tests/parallel/parallel_cycle_one_recover.rs b/tests/parallel/parallel_cycle_one_recover.rs index 044fe826..7a32d95c 100644 --- a/tests/parallel/parallel_cycle_one_recover.rs +++ b/tests/parallel/parallel_cycle_one_recover.rs @@ -2,16 +2,9 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. -use salsa::Handle; +use salsa::{DatabaseImpl, Handle}; -use crate::setup::Database; -use crate::setup::Knobs; - -#[salsa::db] -pub(crate) trait Db: salsa::Database + Knobs {} - -#[salsa::db] -impl Db for T {} +use crate::setup::{Knobs, KnobsDatabase}; #[salsa::input] pub(crate) struct MyInput { @@ -19,7 +12,7 @@ pub(crate) struct MyInput { } #[salsa::tracked] -pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // Wait to create the cycle until both threads have entered db.signal(1); db.wait_for(2); @@ -28,17 +21,17 @@ pub(crate) fn a1(db: &dyn Db, input: MyInput) -> i32 { } #[salsa::tracked(recovery_fn=recover)] -pub(crate) fn a2(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { b1(db, input) } -fn recover(db: &dyn Db, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +fn recover(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { dbg!("recover"); key.field(db) * 20 + 2 } #[salsa::tracked] -pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { // Wait to create the cycle until both threads have entered db.wait_for(1); db.signal(2); @@ -49,7 +42,7 @@ pub(crate) fn b1(db: &dyn Db, input: MyInput) -> i32 { } #[salsa::tracked] -pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { +pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { a1(db, input) } @@ -77,7 +70,7 @@ pub(crate) fn b2(db: &dyn Db, input: MyInput) -> i32 { #[test] fn execute() { - let db = Handle::new(Database::default()); + let db = Handle::new(>::default()); db.knobs().signal_on_will_block.store(3); let input = MyInput::new(&*db, 1); diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index c7c00ee8..6410f853 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -1,11 +1,13 @@ use crossbeam::atomic::AtomicCell; +use salsa::{Database, DatabaseImpl, UserData}; use crate::signal::Signal; /// Various "knobs" and utilities used by tests to force /// a certain behavior. -pub(crate) trait Knobs { - fn knobs(&self) -> &KnobsStruct; +#[salsa::db] +pub(crate) trait KnobsDatabase: Database { + fn knobs(&self) -> &Knobs; fn signal(&self, stage: usize); @@ -16,7 +18,7 @@ pub(crate) trait Knobs { /// behave on one specific thread. Note that this state is /// intentionally thread-local (apart from `signal`). #[derive(Default)] -pub(crate) struct KnobsStruct { +pub(crate) struct Knobs { /// A kind of flexible barrier used to coordinate execution across /// threads to ensure we reach various weird states. pub(crate) signal: Signal, @@ -28,39 +30,32 @@ pub(crate) struct KnobsStruct { pub(crate) signal_on_did_cancel: AtomicCell, } -#[salsa::db] -#[derive(Default)] -pub(crate) struct Database { - storage: salsa::Storage, - knobs: KnobsStruct, -} - -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { +impl UserData for Knobs { + fn salsa_event(db: &DatabaseImpl, event: &dyn Fn() -> salsa::Event) { let event = event(); match event.kind { salsa::EventKind::WillBlockOn { .. } => { - self.signal(self.knobs().signal_on_will_block.load()); + db.signal(db.signal_on_will_block.load()); } salsa::EventKind::DidSetCancellationFlag => { - self.signal(self.knobs().signal_on_did_cancel.load()); + db.signal(db.signal_on_did_cancel.load()); } _ => {} } } } -impl Knobs for Database { - fn knobs(&self) -> &KnobsStruct { - &self.knobs +#[salsa::db] +impl KnobsDatabase for DatabaseImpl { + fn knobs(&self) -> &Knobs { + self } fn signal(&self, stage: usize) { - self.knobs.signal.signal(stage); + self.signal.signal(stage); } fn wait_for(&self, stage: usize) { - self.knobs.signal.wait_for(stage); + self.signal.wait_for(stage); } } diff --git a/tests/preverify-struct-with-leaked-data.rs b/tests/preverify-struct-with-leaked-data.rs index 2e311c6f..2c5bdfd5 100644 --- a/tests/preverify-struct-with-leaked-data.rs +++ b/tests/preverify-struct-with-leaked-data.rs @@ -5,41 +5,14 @@ use std::cell::Cell; use expect_test::expect; mod common; -use common::{HasLogger, Logger}; -use salsa::Setter; +use common::{EventLogger, LogDatabase}; +use salsa::{Database, Setter}; use test_log::test; thread_local! { static COUNTER: Cell = const { Cell::new(0) }; } -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - let event = event(); - self.push_log(format!("{event:?}")); - } -} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[salsa::input] struct MyInput { field1: u32, @@ -52,7 +25,7 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn function(db: &dyn Db, input: MyInput) -> usize { +fn function(db: &dyn Database, input: MyInput) -> usize { // Read input 1 let _field1 = input.field1(db); @@ -71,7 +44,7 @@ fn function(db: &dyn Db, input: MyInput) -> usize { #[test] fn test_leaked_inputs_ignored() { - let mut db = Database::default(); + let mut db: salsa::DatabaseImpl = Default::default(); let input = MyInput::new(&db, 10, 20); let result_in_rev_1 = function(&db, input); diff --git a/tests/singleton.rs b/tests/singleton.rs index 041eb309..539d48cb 100644 --- a/tests/singleton.rs +++ b/tests/singleton.rs @@ -3,8 +3,6 @@ //! Singleton structs are created only once. Subsequent `get`s and `new`s after creation return the same `Id`. use expect_test::expect; -mod common; -use common::{HasLogger, Logger}; use salsa::Database as _; use test_log::test; @@ -15,25 +13,9 @@ struct MyInput { id_field: u16, } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn basic() { - let db = Database::default(); + let db = salsa::DatabaseImpl::new(); let input1 = MyInput::new(&db, 3, 4); let input2 = MyInput::get(&db); @@ -46,7 +28,7 @@ fn basic() { #[test] #[should_panic] fn twice() { - let db = Database::default(); + let db = salsa::DatabaseImpl::new(); let input1 = MyInput::new(&db, 3, 4); let input2 = MyInput::get(&db); @@ -58,7 +40,7 @@ fn twice() { #[test] fn debug() { - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = MyInput::new(db, 3, 4); let actual = format!("{:?}", input); let expected = expect!["MyInput { [salsa id]: Id(0), field: 3, id_field: 4 }"]; diff --git a/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs b/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs index f82a684c..a407aee6 100644 --- a/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs +++ b/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs @@ -2,9 +2,6 @@ //! compilation succeeds but execution panics #![allow(warnings)] -#[salsa::db] -trait Db: salsa::Database {} - #[salsa::input] struct MyInput { field: u32, @@ -16,12 +13,15 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn tracked_struct_created_in_another_query<'db>(db: &'db dyn Db, input: MyInput) -> MyTracked<'db> { +fn tracked_struct_created_in_another_query<'db>( + db: &'db dyn salsa::Database, + input: MyInput, +) -> MyTracked<'db> { MyTracked::new(db, input.field(db) * 2) } #[salsa::tracked] -fn tracked_fn<'db>(db: &'db dyn Db, input: MyInput) -> MyTracked<'db> { +fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTracked<'db> { let t = tracked_struct_created_in_another_query(db, input); if input.field(db) != 0 { tracked_fn_extra::specify(db, t, 2222); @@ -30,28 +30,16 @@ fn tracked_fn<'db>(db: &'db dyn Db, input: MyInput) -> MyTracked<'db> { } #[salsa::tracked(specify)] -fn tracked_fn_extra<'db>(_db: &'db dyn Db, _input: MyTracked<'db>) -> u32 { +fn tracked_fn_extra<'db>(_db: &'db dyn salsa::Database, _input: MyTracked<'db>) -> u32 { 0 } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - -#[salsa::db] -impl Db for Database {} - #[test] #[should_panic( expected = "can only use `specify` on salsa structs created during the current tracked fn" )] fn execute_when_specified() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); let tracked = tracked_fn(&db, input); } diff --git a/tests/synthetic_write.rs b/tests/synthetic_write.rs index dd8d17a1..b7629280 100644 --- a/tests/synthetic_write.rs +++ b/tests/synthetic_write.rs @@ -4,12 +4,9 @@ mod common; -use common::{HasLogger, Logger}; +use common::{ExecuteValidateLogger, LogDatabase, Logger}; use expect_test::expect; -use salsa::{Database as _, Durability, Event, EventKind}; - -#[salsa::db] -trait Db: salsa::Database + HasLogger {} +use salsa::{Database, DatabaseImpl, Durability, Event, EventKind}; #[salsa::input] struct MyInput { @@ -17,47 +14,20 @@ struct MyInput { } #[salsa::tracked] -fn tracked_fn(db: &dyn Db, input: MyInput) -> u32 { +fn tracked_fn(db: &dyn Database, input: MyInput) -> u32 { input.field(db) * 2 } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, event: Event) { - if let EventKind::WillExecute { .. } | EventKind::DidValidateMemoizedValue { .. } = - event.kind - { - self.push_log(format!("{:?}", event.kind)); - } - } -} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - -#[salsa::db] -impl Db for Database {} - #[test] fn execute() { - let mut db = Database::default(); + let mut db: DatabaseImpl = Default::default(); let input = MyInput::new(&db, 22); assert_eq!(tracked_fn(&db, input), 44); db.assert_logs(expect![[r#" [ - "WillExecute { database_key: tracked_fn(0) }", + "salsa_event(WillExecute { database_key: tracked_fn(0) })", ]"#]]); // Bumps the revision @@ -68,6 +38,6 @@ fn execute() { db.assert_logs(expect![[r#" [ - "DidValidateMemoizedValue { database_key: tracked_fn(0) }", + "salsa_event(DidValidateMemoizedValue { database_key: tracked_fn(0) })", ]"#]]); } diff --git a/tests/tracked-struct-id-field-bad-eq.rs b/tests/tracked-struct-id-field-bad-eq.rs index 49c4544d..b003a305 100644 --- a/tests/tracked-struct-id-field-bad-eq.rs +++ b/tests/tracked-struct-id-field-bad-eq.rs @@ -1,6 +1,6 @@ //! Test an id field whose `PartialEq` impl is always true. -use salsa::{Database as Db, Setter}; +use salsa::{Database, Setter}; use test_log::test; #[salsa::input] @@ -33,24 +33,14 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn the_fn(db: &dyn Db, input: MyInput) { +fn the_fn(db: &dyn Database, input: MyInput) { let tracked0 = MyTracked::new(db, BadEq::from(input.field(db))); assert_eq!(tracked0.field(db).field, input.field(db)); } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn execute() { - let mut db = Database::default(); - + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, true); the_fn(&db, input); input.set_field(&mut db).to(false); diff --git a/tests/tracked-struct-id-field-bad-hash.rs b/tests/tracked-struct-id-field-bad-hash.rs index c1f455ae..8a391b3b 100644 --- a/tests/tracked-struct-id-field-bad-hash.rs +++ b/tests/tracked-struct-id-field-bad-hash.rs @@ -42,18 +42,9 @@ fn the_fn(db: &dyn Db, input: MyInput) { assert_eq!(tracked0.field(db).field, input.field(db)); } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn execute() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, true); the_fn(&db, input); diff --git a/tests/tracked-struct-unchanged-in-new-rev.rs b/tests/tracked-struct-unchanged-in-new-rev.rs index 7bbfb340..e4633740 100644 --- a/tests/tracked-struct-unchanged-in-new-rev.rs +++ b/tests/tracked-struct-unchanged-in-new-rev.rs @@ -16,18 +16,9 @@ fn tracked_fn(db: &dyn Db, input: MyInput) -> MyTracked<'_> { MyTracked::new(db, input.field(db) / 2) } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn execute() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let input1 = MyInput::new(&db, 22); let input2 = MyInput::new(&db, 44); diff --git a/tests/tracked-struct-value-field-bad-eq.rs b/tests/tracked-struct-value-field-bad-eq.rs index a9d50c6f..ec4cac5f 100644 --- a/tests/tracked-struct-value-field-bad-eq.rs +++ b/tests/tracked-struct-value-field-bad-eq.rs @@ -3,9 +3,9 @@ //! if we were to execute from scratch. use expect_test::expect; -use salsa::{Database as Db, Setter}; +use salsa::{Database, Setter}; mod common; -use common::{HasLogger, Logger}; +use common::LogDatabase; use test_log::test; #[salsa::input] @@ -37,51 +37,24 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn the_fn(db: &dyn Db, input: MyInput) -> bool { +fn the_fn(db: &dyn Database, input: MyInput) -> bool { let tracked = make_tracked_struct(db, input); read_tracked_struct(db, tracked) } #[salsa::tracked] -fn make_tracked_struct(db: &dyn Db, input: MyInput) -> MyTracked<'_> { +fn make_tracked_struct(db: &dyn Database, input: MyInput) -> MyTracked<'_> { MyTracked::new(db, BadEq::from(input.field(db))) } #[salsa::tracked] -fn read_tracked_struct<'db>(db: &'db dyn Db, tracked: MyTracked<'db>) -> bool { +fn read_tracked_struct<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> bool { tracked.field(db).field } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { - let event = event(); - match event.kind { - salsa::EventKind::WillExecute { .. } - | salsa::EventKind::DidValidateMemoizedValue { .. } => { - self.push_log(format!("salsa_event({:?})", event.kind)); - } - _ => {} - } - } -} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn execute() { - let mut db = Database::default(); + let mut db: salsa::DatabaseImpl = Default::default(); let input = MyInput::new(&db, true); let result = the_fn(&db, input); diff --git a/tests/tracked-struct-value-field-not-eq.rs b/tests/tracked-struct-value-field-not-eq.rs index f529f31a..eaf4a30c 100644 --- a/tests/tracked-struct-value-field-not-eq.rs +++ b/tests/tracked-struct-value-field-not-eq.rs @@ -2,7 +2,7 @@ //! This can our "last changed" data to be wrong //! but we *should* always reflect the final values. -use salsa::{Database as Db, Setter}; +use salsa::{Database, Setter}; use test_log::test; #[salsa::input] @@ -28,23 +28,14 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn the_fn(db: &dyn Db, input: MyInput) { +fn the_fn(db: &dyn Database, input: MyInput) { let tracked0 = MyTracked::new(db, NotEq::from(input.field(db))); assert_eq!(tracked0.field(db).field, input.field(db)); } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn execute() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, true); the_fn(&db, input); diff --git a/tests/tracked_fn_constant.rs b/tests/tracked_fn_constant.rs index db443a79..b53f1b15 100644 --- a/tests/tracked_fn_constant.rs +++ b/tests/tracked_fn_constant.rs @@ -9,15 +9,6 @@ fn tracked_fn(db: &dyn salsa::Database) -> u32 { #[test] fn execute() { - #[salsa::db] - #[derive(Default)] - struct Database { - storage: salsa::Storage, - } - - #[salsa::db] - impl salsa::Database for Database {} - - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); assert_eq!(tracked_fn(&db), 44); } diff --git a/tests/tracked_fn_no_eq.rs b/tests/tracked_fn_no_eq.rs index 77d45050..ee7c5651 100644 --- a/tests/tracked_fn_no_eq.rs +++ b/tests/tracked_fn_no_eq.rs @@ -1,11 +1,8 @@ mod common; -use common::{HasLogger, Logger}; +use common::{LogDatabase, Logger}; use expect_test::expect; -use salsa::Setter as _; - -#[salsa::db] -trait Db: salsa::Database + HasLogger {} +use salsa::{DatabaseImpl, Setter as _}; #[salsa::input] struct Input { @@ -13,7 +10,7 @@ struct Input { } #[salsa::tracked(no_eq)] -fn abs_float(db: &dyn Db, input: Input) -> f32 { +fn abs_float(db: &dyn LogDatabase, input: Input) -> f32 { let number = input.number(db); db.push_log(format!("abs_float({number})")); @@ -21,35 +18,15 @@ fn abs_float(db: &dyn Db, input: Input) -> f32 { } #[salsa::tracked] -fn derived(db: &dyn Db, input: Input) -> u32 { +fn derived(db: &dyn LogDatabase, input: Input) -> u32 { let x = abs_float(db, input); db.push_log("derived".to_string()); x as u32 } - -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - -#[salsa::db] -impl salsa::Database for Database {} - -#[salsa::db] -impl Db for Database {} - #[test] fn invoke() { - let mut db = Database::default(); + let mut db: DatabaseImpl = Default::default(); let input = Input::new(&db, 5); let x = derived(&db, input); diff --git a/tests/tracked_fn_on_input.rs b/tests/tracked_fn_on_input.rs index 2ea3fa6f..e588a40a 100644 --- a/tests/tracked_fn_on_input.rs +++ b/tests/tracked_fn_on_input.rs @@ -14,16 +14,7 @@ fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> u32 { #[test] fn execute() { - #[salsa::db] - #[derive(Default)] - struct Database { - storage: salsa::Storage, - } - - #[salsa::db] - impl salsa::Database for Database {} - - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); assert_eq!(tracked_fn(&db, input), 44); } diff --git a/tests/tracked_fn_on_interned.rs b/tests/tracked_fn_on_interned.rs index 852c3428..b551b880 100644 --- a/tests/tracked_fn_on_interned.rs +++ b/tests/tracked_fn_on_interned.rs @@ -11,18 +11,9 @@ fn tracked_fn<'db>(db: &'db dyn salsa::Database, name: Name<'db>) -> String { name.name(db).clone() } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn execute() { - let db = Database::default(); + let db = salsa::DatabaseImpl::new(); let name = Name::new(&db, "Salsa".to_string()); assert_eq!(tracked_fn(&db, name), "Salsa"); diff --git a/tests/tracked_fn_on_tracked.rs b/tests/tracked_fn_on_tracked.rs index 72aae634..967bbd55 100644 --- a/tests/tracked_fn_on_tracked.rs +++ b/tests/tracked_fn_on_tracked.rs @@ -16,18 +16,9 @@ fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> MyTracked<'_> { MyTracked::new(db, input.field(db) * 2) } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn execute() { - let db = Database::default(); + let db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); assert_eq!(tracked_fn(&db, input).field(&db), 44); } diff --git a/tests/tracked_fn_on_tracked_specify.rs b/tests/tracked_fn_on_tracked_specify.rs index 9a76b498..70e4997a 100644 --- a/tests/tracked_fn_on_tracked_specify.rs +++ b/tests/tracked_fn_on_tracked_specify.rs @@ -26,18 +26,9 @@ fn tracked_fn_extra<'db>(_db: &'db dyn salsa::Database, _input: MyTracked<'db>) 0 } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn execute_when_specified() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); let tracked = tracked_fn(&db, input); assert_eq!(tracked.field(&db), 44); @@ -46,7 +37,7 @@ fn execute_when_specified() { #[test] fn execute_when_not_specified() { - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 0); let tracked = tracked_fn(&db, input); assert_eq!(tracked.field(&db), 0); diff --git a/tests/tracked_fn_read_own_entity.rs b/tests/tracked_fn_read_own_entity.rs index 588a221f..ad4a7002 100644 --- a/tests/tracked_fn_read_own_entity.rs +++ b/tests/tracked_fn_read_own_entity.rs @@ -3,20 +3,17 @@ use expect_test::expect; mod common; -use common::{HasLogger, Logger}; +use common::{LogDatabase, Logger}; use salsa::Setter; use test_log::test; -#[salsa::db] -trait Db: salsa::Database + HasLogger {} - #[salsa::input] struct MyInput { field: u32, } #[salsa::tracked] -fn final_result(db: &dyn Db, input: MyInput) -> u32 { +fn final_result(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("final_result({:?})", input)); intermediate_result(db, input).field(db) * 2 } @@ -27,35 +24,16 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn intermediate_result(db: &dyn Db, input: MyInput) -> MyTracked<'_> { +fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { db.push_log(format!("intermediate_result({:?})", input)); let tracked = MyTracked::new(db, input.field(db) / 2); let _ = tracked.field(db); // read the field of an entity we created tracked } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database {} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn one_entity() { - let mut db = Database::default(); + let mut db: salsa::DatabaseImpl = Default::default(); let input = MyInput::new(&db, 22); assert_eq!(final_result(&db, input), 22); @@ -86,7 +64,7 @@ fn one_entity() { /// Create and mutate a distinct input. No re-execution required. #[test] fn red_herring() { - let mut db = Database::default(); + let mut db: salsa::DatabaseImpl = Default::default(); let input = MyInput::new(&db, 22); assert_eq!(final_result(&db, input), 22); diff --git a/tests/tracked_fn_read_own_specify.rs b/tests/tracked_fn_read_own_specify.rs index 63d2134e..0c643366 100644 --- a/tests/tracked_fn_read_own_specify.rs +++ b/tests/tracked_fn_read_own_specify.rs @@ -1,10 +1,7 @@ use expect_test::expect; -use salsa::Database as SalsaDatabase; mod common; -use common::{HasLogger, Logger}; - -#[salsa::db] -trait Db: salsa::Database + HasLogger {} +use common::{LogDatabase, Logger}; +use salsa::Database; #[salsa::input] struct MyInput { @@ -17,7 +14,7 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn tracked_fn(db: &dyn Db, input: MyInput) -> u32 { +fn tracked_fn(db: &dyn LogDatabase, input: MyInput) -> u32 { db.push_log(format!("tracked_fn({input:?})")); let t = MyTracked::new(db, input.field(db) * 2); tracked_fn_extra::specify(db, t, 2222); @@ -25,33 +22,14 @@ fn tracked_fn(db: &dyn Db, input: MyInput) -> u32 { } #[salsa::tracked(specify)] -fn tracked_fn_extra<'db>(db: &dyn Db, input: MyTracked<'db>) -> u32 { +fn tracked_fn_extra<'db>(db: &dyn LogDatabase, input: MyTracked<'db>) -> u32 { db.push_log(format!("tracked_fn_extra({input:?})")); 0 } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, - logger: Logger, -} - -#[salsa::db] -impl salsa::Database for Database {} - -#[salsa::db] -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn execute() { - let mut db = Database::default(); + let mut db: salsa::DatabaseImpl = salsa::DatabaseImpl::default(); let input = MyInput::new(&db, 22); assert_eq!(tracked_fn(&db, input), 2222); db.assert_logs(expect![[r#" diff --git a/tests/tracked_fn_return_ref.rs b/tests/tracked_fn_return_ref.rs index d4cab330..ecd91a17 100644 --- a/tests/tracked_fn_return_ref.rs +++ b/tests/tracked_fn_return_ref.rs @@ -1,4 +1,4 @@ -use salsa::Database as _; +use salsa::Database; #[salsa::input] struct Input { @@ -12,18 +12,9 @@ fn test(db: &dyn salsa::Database, input: Input) -> Vec { .collect() } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn invoke() { - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = Input::new(db, 3); let x: &Vec = test(db, input); expect_test::expect![[r#" diff --git a/tests/tracked_method.rs b/tests/tracked_method.rs index 10bc5ad6..0291a748 100644 --- a/tests/tracked_method.rs +++ b/tests/tracked_method.rs @@ -34,16 +34,7 @@ impl TrackedTrait for MyInput { #[test] fn execute() { - #[salsa::db] - #[derive(Default)] - struct Database { - storage: salsa::Storage, - } - - #[salsa::db] - impl salsa::Database for Database {} - - let mut db = Database::default(); + let mut db = salsa::DatabaseImpl::new(); let object = MyInput::new(&mut db, 22); // assert_eq!(object.tracked_fn(&db), 44); // assert_eq!(*object.tracked_fn_ref(&db), 66); diff --git a/tests/tracked_method_inherent_return_ref.rs b/tests/tracked_method_inherent_return_ref.rs index e6fa65a8..462a24da 100644 --- a/tests/tracked_method_inherent_return_ref.rs +++ b/tests/tracked_method_inherent_return_ref.rs @@ -1,4 +1,4 @@ -use salsa::Database as _; +use salsa::Database; #[salsa::input] struct Input { @@ -15,18 +15,9 @@ impl Input { } } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn invoke() { - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = Input::new(db, 3); let x: &Vec = input.test(db); expect_test::expect![[r#" diff --git a/tests/tracked_method_on_tracked_struct.rs b/tests/tracked_method_on_tracked_struct.rs index b50bd9d6..1febcfd3 100644 --- a/tests/tracked_method_on_tracked_struct.rs +++ b/tests/tracked_method_on_tracked_struct.rs @@ -43,7 +43,7 @@ impl<'db1> ItemName<'db1> for SourceTree<'db1> { #[test] fn test_inherent() { - salsa::default_database().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = Input::new(db, "foo".to_string()); let source_tree = input.source_tree(db); expect_test::expect![[r#" @@ -55,7 +55,7 @@ fn test_inherent() { #[test] fn test_trait() { - salsa::default_database().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = Input::new(db, "foo".to_string()); let source_tree = input.source_tree(db); expect_test::expect![[r#" diff --git a/tests/tracked_method_trait_return_ref.rs b/tests/tracked_method_trait_return_ref.rs index 80d4035c..3c9fa5cc 100644 --- a/tests/tracked_method_trait_return_ref.rs +++ b/tests/tracked_method_trait_return_ref.rs @@ -1,4 +1,4 @@ -use salsa::Database as _; +use salsa::Database; #[salsa::input] struct Input { @@ -19,18 +19,9 @@ impl Trait for Input { } } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn invoke() { - Database::default().attach(|db| { + salsa::DatabaseImpl::new().attach(|db| { let input = Input::new(db, 3); let x: &Vec = input.test(db); expect_test::expect![[r#" diff --git a/tests/tracked_struct_db1_lt.rs b/tests/tracked_struct_db1_lt.rs index 6931ea7e..e5de757c 100644 --- a/tests/tracked_struct_db1_lt.rs +++ b/tests/tracked_struct_db1_lt.rs @@ -20,14 +20,5 @@ struct MyTracked2<'db2> { field: u32, } -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[test] fn create_db() {} diff --git a/tests/tracked_with_intern.rs b/tests/tracked_with_intern.rs index 508a93c1..a8a72e8c 100644 --- a/tests/tracked_with_intern.rs +++ b/tests/tracked_with_intern.rs @@ -3,15 +3,6 @@ use test_log::test; -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[salsa::input] struct MyInput { field: String, diff --git a/tests/tracked_with_struct_db.rs b/tests/tracked_with_struct_db.rs index 232538c7..236edede 100644 --- a/tests/tracked_with_struct_db.rs +++ b/tests/tracked_with_struct_db.rs @@ -1,18 +1,9 @@ //! Test that a setting a field on a `#[salsa::input]` //! overwrites and returns the old value. -use salsa::Database as _; +use salsa::{Database, DatabaseImpl}; use test_log::test; -#[salsa::db] -#[derive(Default)] -struct Database { - storage: salsa::Storage, -} - -#[salsa::db] -impl salsa::Database for Database {} - #[salsa::input] struct MyInput { field: String, @@ -31,7 +22,7 @@ enum MyList<'db> { } #[salsa::tracked] -fn create_tracked_list(db: &dyn salsa::Database, input: MyInput) -> MyTracked<'_> { +fn create_tracked_list(db: &dyn Database, input: MyInput) -> MyTracked<'_> { let t0 = MyTracked::new(db, input, MyList::None); let t1 = MyTracked::new(db, input, MyList::Next(t0)); t1 @@ -39,7 +30,7 @@ fn create_tracked_list(db: &dyn salsa::Database, input: MyInput) -> MyTracked<'_ #[test] fn execute() { - Database::default().attach(|db| { + DatabaseImpl::new().attach(|db| { let input = MyInput::new(db, "foo".to_string()); let t0: MyTracked = create_tracked_list(db, input); let t1 = create_tracked_list(db, input); From 62f158742c0ad2c04793eddaa23220b5a01f47dc Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sat, 27 Jul 2024 19:08:34 +0000 Subject: [PATCH 07/29] rename Storage to ZalsaImpl, privatize --- src/database.rs | 12 ++++++++---- src/lib.rs | 2 -- src/storage.rs | 8 ++++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/database.rs b/src/database.rs index 52f74e65..55a2b6d6 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,6 +1,10 @@ use std::{any::Any, panic::RefUnwindSafe}; -use crate::{self as salsa, local_state, storage::Zalsa, Durability, Event, Revision, Storage}; +use crate::{ + self as salsa, local_state, + storage::{Zalsa, ZalsaImpl}, + Durability, Event, Revision, +}; /// The trait implemented by all Salsa databases. /// You can create your own subtraits of this trait using the `#[salsa::db]` procedural macro. @@ -98,7 +102,7 @@ impl dyn Database { /// Concrete implementation of the [`Database`][] trait. /// Takes an optional type parameter `U` that allows you to thread your own data. pub struct DatabaseImpl { - storage: Storage, + storage: ZalsaImpl, } impl Default for DatabaseImpl { @@ -113,7 +117,7 @@ impl DatabaseImpl<()> { /// You can also use the [`Default`][] trait if your userdata implements it. pub fn new() -> Self { Self { - storage: Storage::with(()), + storage: ZalsaImpl::with(()), } } } @@ -124,7 +128,7 @@ impl DatabaseImpl { /// You can also use the [`Default`][] trait if your userdata implements it. pub fn with(u: U) -> Self { Self { - storage: Storage::with(u), + storage: ZalsaImpl::with(u), } } } diff --git a/src/lib.rs b/src/lib.rs index 7b38a291..098cd9d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,7 +42,6 @@ pub use self::input::setter::Setter; pub use self::key::DatabaseKeyIndex; pub use self::revision::Revision; pub use self::runtime::Runtime; -pub use self::storage::Storage; pub use self::update::Update; pub use crate::local_state::with_attached_database; pub use salsa_macros::accumulator; @@ -89,7 +88,6 @@ pub mod plumbing { pub use crate::storage::views; pub use crate::storage::IngredientCache; pub use crate::storage::IngredientIndex; - pub use crate::storage::Storage; pub use crate::storage::Zalsa; pub use crate::tracked_struct::TrackedStructInDb; pub use crate::update::always_update; diff --git a/src/storage.rs b/src/storage.rs index f05c22f8..43ac7670 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -72,7 +72,7 @@ pub trait Zalsa { fn report_tracked_write(&mut self, durability: Durability); } -impl Zalsa for Storage { +impl Zalsa for ZalsaImpl { fn views(&self) -> &Views { &self.views_of } @@ -212,7 +212,7 @@ impl IngredientIndex { /// The "storage" struct stores all the data for the jars. /// It is shared between the main database and any active snapshots. -pub struct Storage { +pub(crate) struct ZalsaImpl { user_data: U, views_of: ViewsOf>, @@ -241,14 +241,14 @@ pub struct Storage { } // ANCHOR: default -impl Default for Storage { +impl Default for ZalsaImpl { fn default() -> Self { Self::with(Default::default()) } } // ANCHOR_END: default -impl Storage { +impl ZalsaImpl { pub(crate) fn with(user_data: U) -> Self { Self { views_of: Default::default(), From 138ca4b1f31bc5980adc7686e8287b18e41741ac Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sat, 27 Jul 2024 19:57:47 +0000 Subject: [PATCH 08/29] merge handle into the database Separate handles are no longer needed. --- .../src/setup_input_struct.rs | 3 +- src/database.rs | 97 ++++++++++++-- src/handle.rs | 125 ------------------ src/lib.rs | 2 - src/storage.rs | 47 +++---- tests/parallel/parallel_cancellation.rs | 9 +- tests/parallel/parallel_cycle_all_recover.rs | 9 +- tests/parallel/parallel_cycle_mid_recover.rs | 10 +- tests/parallel/parallel_cycle_none_recover.rs | 13 +- tests/parallel/parallel_cycle_one_recover.rs | 10 +- tests/preverify-struct-with-leaked-data.rs | 1 + 11 files changed, 131 insertions(+), 195 deletions(-) delete mode 100644 src/handle.rs diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index ccdc60c9..d89d63d6 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -89,7 +89,8 @@ macro_rules! setup_input_struct { pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, $zalsa::Revision) { let zalsa_mut = db.zalsa_mut(); let index = zalsa_mut.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()); - let (ingredient, current_revision) = zalsa_mut.lookup_ingredient_mut(index); + let current_revision = zalsa_mut.current_revision(); + let ingredient = zalsa_mut.lookup_ingredient_mut(index); let ingredient = ingredient.assert_type_mut::<$zalsa_struct::IngredientImpl>(); (ingredient, current_revision) } diff --git a/src/database.rs b/src/database.rs index 55a2b6d6..45c5c515 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,9 +1,11 @@ -use std::{any::Any, panic::RefUnwindSafe}; +use std::{any::Any, panic::RefUnwindSafe, sync::Arc}; + +use parking_lot::{Condvar, Mutex}; use crate::{ self as salsa, local_state, storage::{Zalsa, ZalsaImpl}, - Durability, Event, Revision, + Durability, Event, EventKind, Revision, }; /// The trait implemented by all Salsa databases. @@ -34,7 +36,6 @@ pub unsafe trait Database: AsDynDatabase + Any { /// is owned by the current thread, this could trigger deadlock. fn synthetic_write(&mut self, durability: Durability) { let zalsa_mut = self.zalsa_mut(); - zalsa_mut.new_revision(); zalsa_mut.report_tracked_write(durability); } @@ -57,10 +58,14 @@ pub unsafe trait Database: AsDynDatabase + Any { local_state::attach(self, |_state| op(self)) } - /// Plumbing methods. + /// Plumbing method: Access the internal salsa methods. #[doc(hidden)] fn zalsa(&self) -> &dyn Zalsa; + /// Plumbing method: Access the internal salsa methods for mutating the database. + /// + /// **WARNING:** Triggers a new revision, canceling other database handles. + /// This can lead to deadlock! #[doc(hidden)] fn zalsa_mut(&mut self) -> &mut dyn Zalsa; } @@ -102,7 +107,11 @@ impl dyn Database { /// Concrete implementation of the [`Database`][] trait. /// Takes an optional type parameter `U` that allows you to thread your own data. pub struct DatabaseImpl { - storage: ZalsaImpl, + /// Reference to the database. This is always `Some` except during destruction. + zalsa_impl: Option>>, + + /// Coordination data. + coordinate: Arc, } impl Default for DatabaseImpl { @@ -116,9 +125,7 @@ impl DatabaseImpl<()> { /// /// You can also use the [`Default`][] trait if your userdata implements it. pub fn new() -> Self { - Self { - storage: ZalsaImpl::with(()), - } + Self::with(()) } } @@ -128,16 +135,47 @@ impl DatabaseImpl { /// You can also use the [`Default`][] trait if your userdata implements it. pub fn with(u: U) -> Self { Self { - storage: ZalsaImpl::with(u), + zalsa_impl: Some(Arc::new(ZalsaImpl::with(u))), + coordinate: Arc::new(Coordinate { + clones: Mutex::new(1), + cvar: Default::default(), + }), + } + } + + fn zalsa_impl(&self) -> &Arc> { + self.zalsa_impl.as_ref().unwrap() + } + + // ANCHOR: cancel_other_workers + /// Sets cancellation flag and blocks until all other workers with access + /// to this storage have completed. + /// + /// This could deadlock if there is a single worker with two handles to the + /// same database! + fn cancel_others(&mut self) { + let zalsa = self.zalsa_impl(); + zalsa.set_cancellation_flag(); + + self.salsa_event(&|| Event { + thread_id: std::thread::current().id(), + + kind: EventKind::DidSetCancellationFlag, + }); + + let mut clones = self.coordinate.clones.lock(); + while *clones != 1 { + self.coordinate.cvar.wait(&mut clones); } } + // ANCHOR_END: cancel_other_workers } impl std::ops::Deref for DatabaseImpl { type Target = U; fn deref(&self) -> &U { - &self.storage.user_data() + self.zalsa_impl().user_data() } } @@ -146,11 +184,17 @@ impl RefUnwindSafe for DatabaseImpl {} #[salsa_macros::db] unsafe impl Database for DatabaseImpl { fn zalsa(&self) -> &dyn Zalsa { - &self.storage + &**self.zalsa_impl() } fn zalsa_mut(&mut self) -> &mut dyn Zalsa { - &mut self.storage + self.cancel_others(); + + // The ref count on the `Arc` should now be 1 + let arc_zalsa_mut = self.zalsa_impl.as_mut().unwrap(); + let zalsa_mut = Arc::get_mut(arc_zalsa_mut).unwrap(); + zalsa_mut.new_revision(); + zalsa_mut } // Report a salsa event. @@ -159,6 +203,28 @@ unsafe impl Database for DatabaseImpl { } } +impl Clone for DatabaseImpl { + fn clone(&self) -> Self { + *self.coordinate.clones.lock() += 1; + + Self { + zalsa_impl: self.zalsa_impl.clone(), + coordinate: Arc::clone(&self.coordinate), + } + } +} + +impl Drop for DatabaseImpl { + fn drop(&mut self) { + // Drop the database handle *first* + self.zalsa_impl.take(); + + // *Now* decrement the number of clones and notify once we have completed + *self.coordinate.clones.lock() -= 1; + self.coordinate.cvar.notify_all(); + } +} + pub trait UserData: Any + Sized { /// Callback invoked by the [`Database`][] at key points during salsa execution. /// By overriding this method, you can inject logging or other custom behavior. @@ -174,3 +240,10 @@ pub trait UserData: Any + Sized { } impl UserData for () {} + +struct Coordinate { + /// Counter of the number of clones of actor. Begins at 1. + /// Incremented when cloned, decremented when dropped. + clones: Mutex, + cvar: Condvar, +} diff --git a/src/handle.rs b/src/handle.rs deleted file mode 100644 index e3c2ecd9..00000000 --- a/src/handle.rs +++ /dev/null @@ -1,125 +0,0 @@ -use std::sync::Arc; - -use parking_lot::{Condvar, Mutex}; - -use crate::{Database, Event, EventKind}; - -/// A database "handle" allows coordination of multiple async tasks accessing the same database. -/// So long as you are just doing reads, you can freely clone. -/// When you attempt to modify the database, you call `get_mut`, which will set the cancellation flag, -/// causing other handles to get panics. Once all other handles are dropped, you can proceed. -pub struct Handle { - /// Reference to the database. This is always `Some` except during destruction. - db: Option>, - - /// Coordination data. - coordinate: Arc, -} - -struct Coordinate { - /// Counter of the number of clones of actor. Begins at 1. - /// Incremented when cloned, decremented when dropped. - clones: Mutex, - cvar: Condvar, -} - -impl Handle { - /// Create a new handle wrapping `db`. - pub fn new(db: Db) -> Self { - Self { - db: Some(Arc::new(db)), - coordinate: Arc::new(Coordinate { - clones: Mutex::new(1), - cvar: Default::default(), - }), - } - } - - fn db(&self) -> &Arc { - self.db.as_ref().unwrap() - } - - fn db_mut(&mut self) -> &mut Arc { - self.db.as_mut().unwrap() - } - - /// Returns a mutable reference to the inner database. - /// If other handles are active, this method sets the cancellation flag - /// and blocks until they are dropped. - pub fn get_mut(&mut self) -> &mut Db { - self.cancel_others(); - - // Once cancellation above completes, the other handles are being dropped. - // However, because the signal is sent before the destructor completes, it's - // possible that they have not *yet* dropped. - // - // Therefore, we may have to do a (short) bit of - // spinning before we observe the thread-count reducing to 0. - // - // An alternative would be to - Arc::get_mut(self.db_mut()).expect("other threads remain active despite cancellation") - } - - /// Returns the inner database, consuming the handle. - /// - /// If other handles are active, this method sets the cancellation flag - /// and blocks until they are dropped. - pub fn into_inner(mut self) -> Db { - self.cancel_others(); - Arc::into_inner(self.db.take().unwrap()) - .expect("other threads remain active despite cancellation") - } - - // ANCHOR: cancel_other_workers - /// Sets cancellation flag and blocks until all other workers with access - /// to this storage have completed. - /// - /// This could deadlock if there is a single worker with two handles to the - /// same database! - fn cancel_others(&mut self) { - let zalsa = self.db().zalsa(); - zalsa.set_cancellation_flag(); - - self.db().salsa_event(&|| Event { - thread_id: std::thread::current().id(), - - kind: EventKind::DidSetCancellationFlag, - }); - - let mut clones = self.coordinate.clones.lock(); - while *clones != 1 { - self.coordinate.cvar.wait(&mut clones); - } - } - // ANCHOR_END: cancel_other_workers -} - -impl Drop for Handle { - fn drop(&mut self) { - // Drop the database handle *first* - self.db.take(); - - // *Now* decrement the number of clones and notify once we have completed - *self.coordinate.clones.lock() -= 1; - self.coordinate.cvar.notify_all(); - } -} - -impl std::ops::Deref for Handle { - type Target = Db; - - fn deref(&self) -> &Self::Target { - self.db() - } -} - -impl Clone for Handle { - fn clone(&self) -> Self { - *self.coordinate.clones.lock() += 1; - - Self { - db: Some(Arc::clone(self.db())), - coordinate: Arc::clone(&self.coordinate), - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 098cd9d9..8aa9685e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,6 @@ mod database; mod durability; mod event; mod function; -mod handle; mod hash; mod id; mod ingredient; @@ -36,7 +35,6 @@ pub use self::database::UserData; pub use self::durability::Durability; pub use self::event::Event; pub use self::event::EventKind; -pub use self::handle::Handle; pub use self::id::Id; pub use self::input::setter::Setter; pub use self::key::DatabaseKeyIndex; diff --git a/src/storage.rs b/src/storage.rs index 43ac7670..77fc6210 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -41,24 +41,16 @@ pub trait Zalsa { fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient; /// Gets an `&mut`-ref to an ingredient by index. - /// - /// **Triggers a new revision.** Returns the `&mut` reference - /// along with the new revision index. fn lookup_ingredient_mut( &mut self, index: IngredientIndex, - ) -> (&mut dyn Ingredient, Revision); + ) -> &mut dyn Ingredient; fn runtimex(&self) -> &Runtime; /// Return the current revision fn current_revision(&self) -> Revision; - /// Increment revision counter. - /// - /// **Triggers a new revision.** - fn new_revision(&mut self) -> Revision; - /// Return the time when an input of durability `durability` last changed fn last_changed_revision(&self, durability: Durability) -> Revision; @@ -126,22 +118,10 @@ impl Zalsa for ZalsaImpl { fn lookup_ingredient_mut( &mut self, index: IngredientIndex, - ) -> (&mut dyn Ingredient, Revision) { - let new_revision = self.runtime.new_revision(); - - for index in self.ingredients_requiring_reset.iter() { - self.ingredients_vec - .get_mut(index.as_usize()) - .unwrap() - .reset_for_new_revision(); - } - - ( - &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap(), - new_revision, - ) + ) -> &mut dyn Ingredient { + &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap() } - + fn current_revision(&self) -> Revision { self.runtime.current_revision() } @@ -165,10 +145,6 @@ impl Zalsa for ZalsaImpl { fn set_cancellation_flag(&self) { self.runtime.set_cancellation_flag() } - - fn new_revision(&mut self) -> Revision { - self.runtime.new_revision() - } } /// Nonce type representing the underlying database storage. @@ -264,6 +240,21 @@ impl ZalsaImpl { pub(crate) fn user_data(&self) -> &U { &self.user_data } + + /// Triggers a new revision. Invoked automatically when you call `zalsa_mut` + /// and so doesn't need to be called otherwise. + pub(crate) fn new_revision(&mut self) -> Revision { + let new_revision = self.runtime.new_revision(); + + for index in self.ingredients_requiring_reset.iter() { + self.ingredients_vec + .get_mut(index.as_usize()) + .unwrap() + .reset_for_new_revision(); + } + + new_revision + } } /// Caches a pointer to an ingredient in a database. diff --git a/tests/parallel/parallel_cancellation.rs b/tests/parallel/parallel_cancellation.rs index 0e35ab25..55f81e8d 100644 --- a/tests/parallel/parallel_cancellation.rs +++ b/tests/parallel/parallel_cancellation.rs @@ -4,7 +4,6 @@ use salsa::Cancelled; use salsa::DatabaseImpl; -use salsa::Handle; use salsa::Setter; use crate::setup::Knobs; @@ -44,17 +43,17 @@ fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> MyInput { #[test] fn execute() { - let mut db = Handle::new(>::default()); + let mut db = >::default(); db.knobs().signal_on_will_block.store(3); - let input = MyInput::new(&*db, 1); + let input = MyInput::new(&db, 1); let thread_a = std::thread::spawn({ let db = db.clone(); - move || a1(&*db, input) + move || a1(&db, input) }); - input.set_field(db.get_mut()).to(2); + input.set_field(&mut db).to(2); // Assert thread A *should* was cancelled let cancelled = thread_a diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs index 7706d6ec..ac20b504 100644 --- a/tests/parallel/parallel_cycle_all_recover.rs +++ b/tests/parallel/parallel_cycle_all_recover.rs @@ -3,7 +3,6 @@ //! both intra and cross thread. use salsa::DatabaseImpl; -use salsa::Handle; use crate::setup::Knobs; use crate::setup::KnobsDatabase; @@ -87,19 +86,19 @@ fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i3 #[test] fn execute() { - let db = Handle::new(>::default()); + let db = >::default(); db.knobs().signal_on_will_block.store(3); - let input = MyInput::new(&*db, 1); + let input = MyInput::new(&db, 1); let thread_a = std::thread::spawn({ let db = db.clone(); - move || a1(&*db, input) + move || a1(&db, input) }); let thread_b = std::thread::spawn({ let db = db.clone(); - move || b1(&*db, input) + move || b1(&db, input) }); assert_eq!(thread_a.join().unwrap(), 11); diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs index 0c5e3475..8bca2f61 100644 --- a/tests/parallel/parallel_cycle_mid_recover.rs +++ b/tests/parallel/parallel_cycle_mid_recover.rs @@ -2,7 +2,7 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. -use salsa::{DatabaseImpl, Handle}; +use salsa::DatabaseImpl; use crate::setup::{Knobs, KnobsDatabase}; @@ -81,19 +81,19 @@ fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i3 #[test] fn execute() { - let db = Handle::new(>::default()); + let db = >::default(); db.knobs().signal_on_will_block.store(3); - let input = MyInput::new(&*db, 1); + let input = MyInput::new(&db, 1); let thread_a = std::thread::spawn({ let db = db.clone(); - move || a1(&*db, input) + move || a1(&db, input) }); let thread_b = std::thread::spawn({ let db = db.clone(); - move || b1(&*db, input) + move || b1(&db, input) }); // We expect that the recovery function yields diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs index 39b6299c..d74aa5b0 100644 --- a/tests/parallel/parallel_cycle_none_recover.rs +++ b/tests/parallel/parallel_cycle_none_recover.rs @@ -5,9 +5,8 @@ use crate::setup::Knobs; use crate::setup::KnobsDatabase; use expect_test::expect; -use salsa::Database as _; +use salsa::Database; use salsa::DatabaseImpl; -use salsa::Handle; #[salsa::input] pub(crate) struct MyInput { @@ -38,19 +37,19 @@ pub(crate) fn b(db: &dyn KnobsDatabase, input: MyInput) -> i32 { #[test] fn execute() { - let db = Handle::new(>::default()); + let db = >::default(); db.knobs().signal_on_will_block.store(3); - let input = MyInput::new(&*db, -1); + let input = MyInput::new(&db, -1); let thread_a = std::thread::spawn({ let db = db.clone(); - move || a(&*db, input) + move || a(&db, input) }); let thread_b = std::thread::spawn({ let db = db.clone(); - move || b(&*db, input) + move || b(&db, input) }); // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately). @@ -64,7 +63,7 @@ fn execute() { b(0), ] "#]]; - expected.assert_debug_eq(&c.all_participants(&*db)); + expected.assert_debug_eq(&c.all_participants(&db)); } else { panic!("b failed in an unexpected way: {:?}", err_b); } diff --git a/tests/parallel/parallel_cycle_one_recover.rs b/tests/parallel/parallel_cycle_one_recover.rs index 7a32d95c..2bf53857 100644 --- a/tests/parallel/parallel_cycle_one_recover.rs +++ b/tests/parallel/parallel_cycle_one_recover.rs @@ -2,7 +2,7 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. -use salsa::{DatabaseImpl, Handle}; +use salsa::DatabaseImpl; use crate::setup::{Knobs, KnobsDatabase}; @@ -70,19 +70,19 @@ pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { #[test] fn execute() { - let db = Handle::new(>::default()); + let db = >::default(); db.knobs().signal_on_will_block.store(3); - let input = MyInput::new(&*db, 1); + let input = MyInput::new(&db, 1); let thread_a = std::thread::spawn({ let db = db.clone(); - move || a1(&*db, input) + move || a1(&db, input) }); let thread_b = std::thread::spawn({ let db = db.clone(); - move || b1(&*db, input) + move || b1(&db, input) }); // We expect that the recovery function yields diff --git a/tests/preverify-struct-with-leaked-data.rs b/tests/preverify-struct-with-leaked-data.rs index 2c5bdfd5..99391709 100644 --- a/tests/preverify-struct-with-leaked-data.rs +++ b/tests/preverify-struct-with-leaked-data.rs @@ -66,6 +66,7 @@ fn test_leaked_inputs_ignored() { let result_in_rev_2 = function(&db, input); db.assert_logs(expect![[r#" [ + "Event { thread_id: ThreadId(2), kind: DidSetCancellationFlag }", "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", "Event { thread_id: ThreadId(2), kind: WillExecute { database_key: function(0) } }", ]"#]]); From 8e9ebbafd31fa74ef8666913b63eaa9f282dc055 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sat, 27 Jul 2024 20:06:02 +0000 Subject: [PATCH 09/29] improve comments --- src/database.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/database.rs b/src/database.rs index 45c5c515..e79e94ea 100644 --- a/src/database.rs +++ b/src/database.rs @@ -110,7 +110,8 @@ pub struct DatabaseImpl { /// Reference to the database. This is always `Some` except during destruction. zalsa_impl: Option>>, - /// Coordination data. + /// Coordination data for cancellation of other handles when `zalsa_mut` is called. + /// This could be stored in ZalsaImpl but it makes things marginally cleaner to keep it separate. coordinate: Arc, } @@ -143,6 +144,9 @@ impl DatabaseImpl { } } + /// Access the `Arc`. This should always be + /// possible as `zalsa_impl` only becomes + /// `None` once we are in the `Drop` impl. fn zalsa_impl(&self) -> &Arc> { self.zalsa_impl.as_ref().unwrap() } From a675810edfd2638a505d2ae0a76acb4988d4a141 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 01:48:20 +0000 Subject: [PATCH 10/29] move local-state into DatabaseImpl Each clone gets an independent local state. --- .../salsa-macro-rules/src/setup_tracked_fn.rs | 32 ++-- src/accumulator.rs | 48 +++-- src/attach.rs | 100 ++++++++++ src/cycle.rs | 4 +- src/database.rs | 27 ++- src/function/accumulated.rs | 73 ++++--- src/function/fetch.rs | 37 ++-- src/function/maybe_changed_after.rs | 52 +++-- src/function/specify.rs | 126 ++++++------ src/input.rs | 31 ++- src/interned.rs | 75 ++++---- src/key.rs | 4 +- src/lib.rs | 6 +- src/local_state.rs | 81 +------- src/storage.rs | 10 +- src/tracked_struct.rs | 180 +++++++++--------- src/tracked_struct/tracked_field.rs | 36 ++-- src/views.rs | 35 +--- tests/accumulate.rs | 2 +- tests/common/mod.rs | 4 +- 20 files changed, 475 insertions(+), 488 deletions(-) create mode 100644 src/attach.rs diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index e1165516..58d88875 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -265,24 +265,26 @@ macro_rules! setup_tracked_fn { } } } - let result = $zalsa::macro_if! { - if $needs_interner { - { - let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*)); - $Configuration::fn_ingredient($db).fetch($db, key) + $zalsa::attach($db, || { + let result = $zalsa::macro_if! { + if $needs_interner { + { + let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*)); + $Configuration::fn_ingredient($db).fetch($db, key) + } + } else { + $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*))) } - } else { - $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*))) - } - }; + }; - $zalsa::macro_if! { - if $return_ref { - result - } else { - <$output_ty as std::clone::Clone>::clone(result) + $zalsa::macro_if! { + if $return_ref { + result + } else { + <$output_ty as std::clone::Clone>::clone(result) + } } - } + }) } }; } diff --git a/src/accumulator.rs b/src/accumulator.rs index dc114bb2..47133790 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -10,7 +10,7 @@ use crate::{ hash::FxDashMap, ingredient::{fmt_index, Ingredient, Jar}, key::DependencyIndex, - local_state::{self, LocalState, QueryOrigin}, + local_state::{LocalState, QueryOrigin}, storage::IngredientIndex, Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, }; @@ -80,32 +80,30 @@ impl IngredientImpl { } pub fn push(&self, db: &dyn crate::Database, value: A) { - local_state::attach(db, |state| { - let current_revision = db.zalsa().current_revision(); - let (active_query, _) = match state.active_query() { - Some(pair) => pair, - None => { - panic!("cannot accumulate values outside of an active query") - } - }; - - let mut accumulated_values = - self.map.entry(active_query).or_insert(AccumulatedValues { - values: vec![], - produced_at: current_revision, - }); - - // When we call `push' in a query, we will add the accumulator to the output of the query. - // If we find here that this accumulator is not the output of the query, - // we can say that the accumulated values we stored for this query is out of date. - if !state.is_output_of_active_query(self.dependency_index()) { - accumulated_values.values.truncate(0); - accumulated_values.produced_at = current_revision; + let state = db.zalsa_local(); + let current_revision = db.zalsa().current_revision(); + let (active_query, _) = match state.active_query() { + Some(pair) => pair, + None => { + panic!("cannot accumulate values outside of an active query") } + }; + + let mut accumulated_values = self.map.entry(active_query).or_insert(AccumulatedValues { + values: vec![], + produced_at: current_revision, + }); + + // When we call `push' in a query, we will add the accumulator to the output of the query. + // If we find here that this accumulator is not the output of the query, + // we can say that the accumulated values we stored for this query is out of date. + if !state.is_output_of_active_query(self.dependency_index()) { + accumulated_values.values.truncate(0); + accumulated_values.produced_at = current_revision; + } - state.add_output(self.dependency_index()); - accumulated_values.values.push(value); - }) + state.add_output(self.dependency_index()); + accumulated_values.values.push(value); } pub(crate) fn produced_by( diff --git a/src/attach.rs b/src/attach.rs new file mode 100644 index 00000000..3dcf9d11 --- /dev/null +++ b/src/attach.rs @@ -0,0 +1,100 @@ +use std::{cell::Cell, ptr::NonNull}; + +use crate::Database; + +thread_local! { + /// The thread-local state salsa requires for a given thread + static ATTACHED: Attached = const { Attached::new() } +} + +/// State that is specific to a single execution thread. +/// +/// Internally, this type uses ref-cells. +/// +/// **Note also that all mutations to the database handle (and hence +/// to the local-state) must be undone during unwinding.** +struct Attached { + /// Pointer to the currently attached database. + database: Cell>>, +} + +impl Attached { + const fn new() -> Self { + Self { + database: Cell::new(None), + } + } + + fn attach(&self, db: &Db, op: impl FnOnce() -> R) -> R + where + Db: ?Sized + Database, + { + struct DbGuard<'s> { + state: Option<&'s Attached>, + } + + impl<'s> DbGuard<'s> { + fn new(attached: &'s Attached, db: &dyn Database) -> Self { + if let Some(current_db) = attached.database.get() { + let new_db = NonNull::from(db); + + // Already attached? Assert that the database has not changed. + // NOTE: It's important to use `addr_eq` here because `NonNull::eq` + // not only compares the address but also the type's metadata. + if !std::ptr::addr_eq(current_db.as_ptr(), new_db.as_ptr()) { + panic!( + "Cannot change database mid-query. current: {current_db:?}, new: {new_db:?}", + ); + } + + Self { state: None } + } else { + // Otherwise, set the database. + attached.database.set(Some(NonNull::from(db))); + Self { + state: Some(attached), + } + } + } + } + + impl Drop for DbGuard<'_> { + fn drop(&mut self) { + // Reset database to null if we did anything in `DbGuard::new`. + if let Some(attached) = self.state { + attached.database.set(None); + } + } + } + + let _guard = DbGuard::new(self, db.as_dyn_database()); + op() + } + + /// Access the "attached" database. Returns `None` if no database is attached. + /// Databases are attached with `attach_database`. + fn with(&self, op: impl FnOnce(&dyn Database) -> R) -> Option { + if let Some(db) = self.database.get() { + // SAFETY: We always attach the database in for the entire duration of a function, + // so it cannot become "unattached" while this function is running. + Some(op(unsafe { db.as_ref() })) + } else { + None + } + } +} + +/// Attach the database to the current thread and execute `op`. +/// Panics if a different database has already been attached. +pub fn attach(db: &Db, op: impl FnOnce() -> R) -> R +where + Db: ?Sized + Database, +{ + ATTACHED.with(|a| a.attach(db, op)) +} + +/// Access the "attached" database. Returns `None` if no database is attached. +/// Databases are attached with `attach_database`. +pub fn with_attached_database(op: impl FnOnce(&dyn Database) -> R) -> Option { + ATTACHED.with(|a| a.with(op)) +} diff --git a/src/cycle.rs b/src/cycle.rs index 4a8a56f4..44558b4a 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,4 +1,4 @@ -use crate::{key::DatabaseKeyIndex, local_state, Database}; +use crate::{key::DatabaseKeyIndex, Database}; use std::{panic::AssertUnwindSafe, sync::Arc}; /// Captures the participants of a cycle that occurred when executing a query. @@ -74,7 +74,7 @@ impl Cycle { impl std::fmt::Debug for Cycle { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - local_state::with_attached_database(|db| { + crate::attach::with_attached_database(|db| { f.debug_struct("UnexpectedCycle") .field("all_participants", &self.all_participants(db)) .field("unexpected_participants", &self.unexpected_participants(db)) diff --git a/src/database.rs b/src/database.rs index e79e94ea..3643c6dd 100644 --- a/src/database.rs +++ b/src/database.rs @@ -3,7 +3,8 @@ use std::{any::Any, panic::RefUnwindSafe, sync::Arc}; use parking_lot::{Condvar, Mutex}; use crate::{ - self as salsa, local_state, + self as salsa, + local_state::{self, LocalState}, storage::{Zalsa, ZalsaImpl}, Durability, Event, EventKind, Revision, }; @@ -16,7 +17,7 @@ use crate::{ /// This trait can only safely be implemented by Salsa's [`DatabaseImpl`][] type. /// FIXME: Document better the unsafety conditions we guarantee. #[salsa_macros::db] -pub unsafe trait Database: AsDynDatabase + Any { +pub unsafe trait Database: Send + AsDynDatabase + Any { /// This function is invoked by the salsa runtime at various points during execution. /// You can customize what happens by implementing the [`UserData`][] trait. /// By default, the event is logged at level debug using tracing facade. @@ -45,9 +46,8 @@ pub unsafe trait Database: AsDynDatabase + Any { /// revision. fn report_untracked_read(&self) { let db = self.as_dyn_database(); - local_state::attach(db, |state| { - state.report_untracked_read(db.zalsa().current_revision()) - }) + let zalsa_local = db.zalsa_local(); + zalsa_local.report_untracked_read(db.zalsa().current_revision()) } /// Execute `op` with the database in thread-local storage for debug print-outs. @@ -55,7 +55,7 @@ pub unsafe trait Database: AsDynDatabase + Any { where Self: Sized, { - local_state::attach(self, |_state| op(self)) + crate::attach::attach(self, || op(self)) } /// Plumbing method: Access the internal salsa methods. @@ -68,6 +68,10 @@ pub unsafe trait Database: AsDynDatabase + Any { /// This can lead to deadlock! #[doc(hidden)] fn zalsa_mut(&mut self) -> &mut dyn Zalsa; + + /// Access the thread-local state associated with this database + #[doc(hidden)] + fn zalsa_local(&self) -> &LocalState; } /// Upcast to a `dyn Database`. @@ -113,6 +117,9 @@ pub struct DatabaseImpl { /// Coordination data for cancellation of other handles when `zalsa_mut` is called. /// This could be stored in ZalsaImpl but it makes things marginally cleaner to keep it separate. coordinate: Arc, + + /// Per-thread state + zalsa_local: local_state::LocalState, } impl Default for DatabaseImpl { @@ -141,6 +148,7 @@ impl DatabaseImpl { clones: Mutex::new(1), cvar: Default::default(), }), + zalsa_local: LocalState::new(), } } @@ -201,6 +209,10 @@ unsafe impl Database for DatabaseImpl { zalsa_mut } + fn zalsa_local(&self) -> &LocalState { + &self.zalsa_local + } + // Report a salsa event. fn salsa_event(&self, event: &dyn Fn() -> Event) { U::salsa_event(self, event) @@ -214,6 +226,7 @@ impl Clone for DatabaseImpl { Self { zalsa_impl: self.zalsa_impl.clone(), coordinate: Arc::clone(&self.coordinate), + zalsa_local: LocalState::new(), } } } @@ -229,7 +242,7 @@ impl Drop for DatabaseImpl { } } -pub trait UserData: Any + Sized { +pub trait UserData: Any + Sized + Send + Sync { /// Callback invoked by the [`Database`][] at key points during salsa execution. /// By overriding this method, you can inject logging or other custom behavior. /// diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 98d11eaa..62c85930 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,4 +1,4 @@ -use crate::{accumulator, hash::FxHashSet, local_state, Database, DatabaseKeyIndex, Id}; +use crate::{accumulator, hash::FxHashSet, Database, DatabaseKeyIndex, Id}; use super::{Configuration, IngredientImpl}; @@ -12,44 +12,41 @@ where where A: accumulator::Accumulator, { - local_state::attach(db, |local_state| { - let zalsa = db.zalsa(); - let current_revision = zalsa.current_revision(); - - let Some(accumulator) = >::from_db(db) else { - return vec![]; - }; - let mut output = vec![]; - - // First ensure the result is up to date - self.fetch(db, key); - - let db_key = self.database_key_index(key); - let mut visited: FxHashSet = FxHashSet::default(); - let mut stack: Vec = vec![db_key]; - - while let Some(k) = stack.pop() { - if visited.insert(k) { - accumulator.produced_by(current_revision, local_state, k, &mut output); - - let origin = zalsa - .lookup_ingredient(k.ingredient_index) - .origin(k.key_index); - let inputs = origin.iter().flat_map(|origin| origin.inputs()); - // Careful: we want to push in execution order, so reverse order to - // ensure the first child that was executed will be the first child popped - // from the stack. - stack.extend( - inputs - .flat_map(|input| { - TryInto::::try_into(input).into_iter() - }) - .rev(), - ); - } + let zalsa = db.zalsa(); + let zalsa_local = db.zalsa_local(); + let current_revision = zalsa.current_revision(); + + let Some(accumulator) = >::from_db(db) else { + return vec![]; + }; + let mut output = vec![]; + + // First ensure the result is up to date + self.fetch(db, key); + + let db_key = self.database_key_index(key); + let mut visited: FxHashSet = FxHashSet::default(); + let mut stack: Vec = vec![db_key]; + + while let Some(k) = stack.pop() { + if visited.insert(k) { + accumulator.produced_by(current_revision, zalsa_local, k, &mut output); + + let origin = zalsa + .lookup_ingredient(k.ingredient_index) + .origin(k.key_index); + let inputs = origin.iter().flat_map(|origin| origin.inputs()); + // Careful: we want to push in execution order, so reverse order to + // ensure the first child that was executed will be the first child popped + // from the stack. + stack.extend( + inputs + .flat_map(|input| TryInto::::try_into(input).into_iter()) + .rev(), + ); } + } - output - }) + output } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index f204145f..4e3018b9 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,9 +1,7 @@ use arc_swap::Guard; use crate::{ - local_state::{self, LocalState}, - runtime::StampedValue, - AsDynDatabase as _, Database as _, Id, + local_state::LocalState, runtime::StampedValue, AsDynDatabase as _, Database as _, Id, }; use super::{Configuration, IngredientImpl}; @@ -13,27 +11,26 @@ where C: Configuration, { pub fn fetch<'db>(&'db self, db: &'db C::DbView, key: Id) -> &C::Output<'db> { - local_state::attach(db.as_dyn_database(), |local_state| { - local_state.unwind_if_revision_cancelled(db.as_dyn_database()); + let zalsa_local = db.zalsa_local(); + zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database()); - let StampedValue { - value, - durability, - changed_at, - } = self.compute_value(db, local_state, key); + let StampedValue { + value, + durability, + changed_at, + } = self.compute_value(db, zalsa_local, key); - if let Some(evicted) = self.lru.record_use(key) { - self.evict(evicted); - } + if let Some(evicted) = self.lru.record_use(key) { + self.evict(evicted); + } - local_state.report_tracked_read( - self.database_key_index(key).into(), - durability, - changed_at, - ); + zalsa_local.report_tracked_read( + self.database_key_index(key).into(), + durability, + changed_at, + ); - value - }) + value } #[inline] diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 15a677d5..0d3fc3d4 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,7 +2,7 @@ use arc_swap::Guard; use crate::{ key::DatabaseKeyIndex, - local_state::{self, ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, + local_state::{ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, runtime::StampedValue, storage::Zalsa, AsDynDatabase as _, Database, Id, Revision, @@ -20,36 +20,32 @@ where key: Id, revision: Revision, ) -> bool { - local_state::attach(db.as_dyn_database(), |local_state| { - let zalsa = db.zalsa(); - local_state.unwind_if_revision_cancelled(db.as_dyn_database()); - - loop { - let database_key_index = self.database_key_index(key); - - tracing::debug!( - "{database_key_index:?}: maybe_changed_after(revision = {revision:?})" - ); - - // Check if we have a verified version: this is the hot path. - let memo_guard = self.memo_map.get(key); - if let Some(memo) = &memo_guard { - if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { - return memo.revisions.changed_at > revision; - } - drop(memo_guard); // release the arc-swap guard before cold path - if let Some(mcs) = self.maybe_changed_after_cold(db, local_state, key, revision) - { - return mcs; - } else { - // We failed to claim, have to retry. - } + let zalsa_local = db.zalsa_local(); + let zalsa = db.zalsa(); + zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database()); + + loop { + let database_key_index = self.database_key_index(key); + + tracing::debug!("{database_key_index:?}: maybe_changed_after(revision = {revision:?})"); + + // Check if we have a verified version: this is the hot path. + let memo_guard = self.memo_map.get(key); + if let Some(memo) = &memo_guard { + if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { + return memo.revisions.changed_at > revision; + } + drop(memo_guard); // release the arc-swap guard before cold path + if let Some(mcs) = self.maybe_changed_after_cold(db, zalsa_local, key, revision) { + return mcs; } else { - // No memo? Assume has changed. - return true; + // We failed to claim, have to retry. } + } else { + // No memo? Assume has changed. + return true; } - }) + } } fn maybe_changed_after_cold<'db>( diff --git a/src/function/specify.rs b/src/function/specify.rs index d8d5dea8..98945dc5 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -1,7 +1,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ - local_state::{self, QueryOrigin, QueryRevisions}, + local_state::{QueryOrigin, QueryRevisions}, tracked_struct::TrackedStructInDb, AsDynDatabase as _, Database, DatabaseKeyIndex, Id, }; @@ -18,76 +18,74 @@ where where C::Input<'db>: TrackedStructInDb, { - local_state::attach(db.as_dyn_database(), |state| { - let (active_query_key, current_deps) = match state.active_query() { - Some(v) => v, - None => panic!("can only use `specify` inside a tracked function"), - }; + let zalsa_local = db.zalsa_local(); - // `specify` only works if the key is a tracked struct created in the current query. - // - // The reason is this. We want to ensure that the same result is reached regardless of - // the "path" that the user takes through the execution graph. - // If you permit values to be specified from other queries, you can have a situation like this: - // * Q0 creates the tracked struct T0 - // * Q1 specifies the value for F(T0) - // * Q2 invokes F(T0) - // * Q3 invokes Q1 and then Q2 - // * Q4 invokes Q2 and then Q1 - // - // Now, if We invoke Q3 first, We get one result for Q2, but if We invoke Q4 first, We get a different value. That's no good. - let database_key_index = >::database_key_index(db.as_dyn_database(), key); - let dependency_index = database_key_index.into(); - if !state.is_output_of_active_query(dependency_index) { - panic!( - "can only use `specify` on salsa structs created during the current tracked fn" - ); - } + let (active_query_key, current_deps) = match zalsa_local.active_query() { + Some(v) => v, + None => panic!("can only use `specify` inside a tracked function"), + }; - // Subtle: we treat the "input" to a set query as if it were - // volatile. - // - // The idea is this. You have the current query C that - // created the entity E, and it is setting the value F(E) of the function F. - // When some other query R reads the field F(E), in order to have obtained - // the entity E, it has to have executed the query C. - // - // This will have forced C to either: - // - // - not create E this time, in which case R shouldn't have it (some kind of leak has occurred) - // - assign a value to F(E), in which case `verified_at` will be the current revision and `changed_at` will be updated appropriately - // - NOT assign a value to F(E), in which case we need to re-execute the function (which typically panics). - // - // So, ruling out the case of a leak having occurred, that means that the reader R will either see: - // - // - a result that is verified in the current revision, because it was set, which will use the set value - // - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic) + // `specify` only works if the key is a tracked struct created in the current query. + // + // The reason is this. We want to ensure that the same result is reached regardless of + // the "path" that the user takes through the execution graph. + // If you permit values to be specified from other queries, you can have a situation like this: + // * Q0 creates the tracked struct T0 + // * Q1 specifies the value for F(T0) + // * Q2 invokes F(T0) + // * Q3 invokes Q1 and then Q2 + // * Q4 invokes Q2 and then Q1 + // + // Now, if We invoke Q3 first, We get one result for Q2, but if We invoke Q4 first, We get a different value. That's no good. + let database_key_index = >::database_key_index(db.as_dyn_database(), key); + let dependency_index = database_key_index.into(); + if !zalsa_local.is_output_of_active_query(dependency_index) { + panic!("can only use `specify` on salsa structs created during the current tracked fn"); + } - let revision = db.zalsa().current_revision(); - let mut revisions = QueryRevisions { - changed_at: current_deps.changed_at, - durability: current_deps.durability, - origin: QueryOrigin::Assigned(active_query_key), - }; + // Subtle: we treat the "input" to a set query as if it were + // volatile. + // + // The idea is this. You have the current query C that + // created the entity E, and it is setting the value F(E) of the function F. + // When some other query R reads the field F(E), in order to have obtained + // the entity E, it has to have executed the query C. + // + // This will have forced C to either: + // + // - not create E this time, in which case R shouldn't have it (some kind of leak has occurred) + // - assign a value to F(E), in which case `verified_at` will be the current revision and `changed_at` will be updated appropriately + // - NOT assign a value to F(E), in which case we need to re-execute the function (which typically panics). + // + // So, ruling out the case of a leak having occurred, that means that the reader R will either see: + // + // - a result that is verified in the current revision, because it was set, which will use the set value + // - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic) - if let Some(old_memo) = self.memo_map.get(key) { - self.backdate_if_appropriate(&old_memo, &mut revisions, &value); - self.diff_outputs(db, database_key_index, &old_memo, &revisions); - } + let revision = db.zalsa().current_revision(); + let mut revisions = QueryRevisions { + changed_at: current_deps.changed_at, + durability: current_deps.durability, + origin: QueryOrigin::Assigned(active_query_key), + }; - let memo = Memo { - value: Some(value), - verified_at: AtomicCell::new(revision), - revisions, - }; + if let Some(old_memo) = self.memo_map.get(key) { + self.backdate_if_appropriate(&old_memo, &mut revisions, &value); + self.diff_outputs(db, database_key_index, &old_memo, &revisions); + } - tracing::debug!("specify: about to add memo {:#?} for key {:?}", memo, key); - self.insert_memo(db, key, memo); + let memo = Memo { + value: Some(value), + verified_at: AtomicCell::new(revision), + revisions, + }; + + tracing::debug!("specify: about to add memo {:#?} for key {:?}", memo, key); + self.insert_memo(db, key, memo); - // Record that the current query *specified* a value for this cell. - let database_key_index = self.database_key_index(key); - state.add_output(database_key_index.into()); - }) + // Record that the current query *specified* a value for this cell. + let database_key_index = self.database_key_index(key); + zalsa_local.add_output(database_key_index.into()); } /// Invoked when the query `executor` has been validated as having green inputs diff --git a/src/input.rs b/src/input.rs index 6fa683e6..62fe321c 100644 --- a/src/input.rs +++ b/src/input.rs @@ -17,7 +17,7 @@ use crate::{ id::{AsId, FromId}, ingredient::{fmt_index, Ingredient}, key::{DatabaseKeyIndex, DependencyIndex}, - local_state::{self, QueryOrigin}, + local_state::QueryOrigin, plumbing::{Jar, Stamp}, storage::IngredientIndex, Database, Durability, Id, Revision, @@ -152,21 +152,20 @@ impl IngredientImpl { id: C::Struct, field_index: usize, ) -> &'db C::Fields { - local_state::attach(db, |state| { - let field_ingredient_index = self.ingredient_index.successor(field_index); - let id = id.as_id(); - let value = self.struct_map.get(id); - let stamp = &value.stamps[field_index]; - state.report_tracked_read( - DependencyIndex { - ingredient_index: field_ingredient_index, - key_index: Some(id), - }, - stamp.durability, - stamp.changed_at, - ); - &value.fields - }) + let zalsa_local = db.zalsa_local(); + let field_ingredient_index = self.ingredient_index.successor(field_index); + let id = id.as_id(); + let value = self.struct_map.get(id); + let stamp = &value.stamps[field_index]; + zalsa_local.report_tracked_read( + DependencyIndex { + ingredient_index: field_ingredient_index, + key_index: Some(id), + }, + stamp.durability, + stamp.changed_at, + ); + &value.fields } /// Peek at the field values without recording any read dependency. diff --git a/src/interned.rs b/src/interned.rs index d3da16a0..55d9fc3b 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -9,7 +9,7 @@ use crate::durability::Durability; use crate::id::AsId; use crate::ingredient::fmt_index; use crate::key::DependencyIndex; -use crate::local_state::{self, QueryOrigin}; +use crate::local_state::QueryOrigin; use crate::plumbing::Jar; use crate::storage::IngredientIndex; use crate::{Database, DatabaseKeyIndex, Id}; @@ -136,46 +136,45 @@ where db: &'db dyn crate::Database, data: C::Data<'db>, ) -> C::Struct<'db> { - local_state::attach(db, |state| { - state.report_tracked_read( - DependencyIndex::for_table(self.ingredient_index), - Durability::MAX, - self.reset_at, - ); - - // Optimisation to only get read lock on the map if the data has already - // been interned. - let internal_data = unsafe { self.to_internal_data(data) }; - if let Some(guard) = self.key_map.get(&internal_data) { - let id = *guard; - drop(guard); - return self.interned_value(id); + let zalsa_local = db.zalsa_local(); + zalsa_local.report_tracked_read( + DependencyIndex::for_table(self.ingredient_index), + Durability::MAX, + self.reset_at, + ); + + // Optimisation to only get read lock on the map if the data has already + // been interned. + let internal_data = unsafe { self.to_internal_data(data) }; + if let Some(guard) = self.key_map.get(&internal_data) { + let id = *guard; + drop(guard); + return self.interned_value(id); + } + + match self.key_map.entry(internal_data.clone()) { + // Data has been interned by a racing call, use that ID instead + dashmap::mapref::entry::Entry::Occupied(entry) => { + let id = *entry.get(); + drop(entry); + self.interned_value(id) } - match self.key_map.entry(internal_data.clone()) { - // Data has been interned by a racing call, use that ID instead - dashmap::mapref::entry::Entry::Occupied(entry) => { - let id = *entry.get(); - drop(entry); - self.interned_value(id) - } - - // We won any races so should intern the data - dashmap::mapref::entry::Entry::Vacant(entry) => { - let next_id = self.counter.fetch_add(1); - let next_id = crate::id::Id::from_u32(next_id); - let value = self.value_map.entry(next_id).or_insert(Alloc::new(Value { - id: next_id, - fields: internal_data, - })); - let value_raw = value.as_raw(); - drop(value); - entry.insert(next_id); - // SAFETY: Items are only removed from the `value_map` with an `&mut self` reference. - unsafe { C::struct_from_raw(value_raw) } - } + // We won any races so should intern the data + dashmap::mapref::entry::Entry::Vacant(entry) => { + let next_id = self.counter.fetch_add(1); + let next_id = crate::id::Id::from_u32(next_id); + let value = self.value_map.entry(next_id).or_insert(Alloc::new(Value { + id: next_id, + fields: internal_data, + })); + let value_raw = value.as_raw(); + drop(value); + entry.insert(next_id); + // SAFETY: Items are only removed from the `value_map` with an `&mut self` reference. + unsafe { C::struct_from_raw(value_raw) } } - }) + } } pub fn interned_value(&self, id: Id) -> C::Struct<'_> { diff --git a/src/key.rs b/src/key.rs index b2b70292..4cb5dd4c 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,4 +1,4 @@ -use crate::{cycle::CycleRecoveryStrategy, local_state, storage::IngredientIndex, Database, Id}; +use crate::{cycle::CycleRecoveryStrategy, storage::IngredientIndex, Database, Id}; /// An integer that uniquely identifies a particular query instance within the /// database. Used to track dependencies between queries. Fully ordered and @@ -60,7 +60,7 @@ impl DependencyIndex { impl std::fmt::Debug for DependencyIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - local_state::with_attached_database(|db| { + crate::attach::with_attached_database(|db| { let ingredient = db.zalsa().lookup_ingredient(self.ingredient_index); ingredient.fmt_index(self.key_index, f) }) diff --git a/src/lib.rs b/src/lib.rs index 8aa9685e..42e5ad00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ mod accumulator; mod active_query; mod alloc; mod array; +mod attach; mod cancelled; mod cycle; mod database; @@ -41,7 +42,7 @@ pub use self::key::DatabaseKeyIndex; pub use self::revision::Revision; pub use self::runtime::Runtime; pub use self::update::Update; -pub use crate::local_state::with_attached_database; +pub use crate::attach::with_attached_database; pub use salsa_macros::accumulator; pub use salsa_macros::db; pub use salsa_macros::input; @@ -63,6 +64,8 @@ pub mod prelude { pub mod plumbing { pub use crate::accumulator::Accumulator; pub use crate::array::Array; + pub use crate::attach::attach; + pub use crate::attach::with_attached_database; pub use crate::cycle::Cycle; pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::current_revision; @@ -76,7 +79,6 @@ pub mod plumbing { pub use crate::ingredient::Ingredient; pub use crate::ingredient::Jar; pub use crate::key::DatabaseKeyIndex; - pub use crate::local_state::with_attached_database; pub use crate::revision::Revision; pub use crate::runtime::stamp; pub use crate::runtime::Runtime; diff --git a/src/local_state.rs b/src/local_state.rs index b9a16c4c..606415a2 100644 --- a/src/local_state.rs +++ b/src/local_state.rs @@ -13,49 +13,16 @@ use crate::Database; use crate::Event; use crate::EventKind; use crate::Revision; -use std::cell::Cell; use std::cell::RefCell; -use std::ptr::NonNull; use std::sync::Arc; -thread_local! { - /// The thread-local state salsa requires for a given thread - static LOCAL_STATE: LocalState = const { LocalState::new() } -} - -/// Attach the database to the current thread and execute `op`. -/// Panics if a different database has already been attached. -pub(crate) fn attach(db: &DB, op: impl FnOnce(&LocalState) -> R) -> R -where - DB: ?Sized + Database, -{ - LOCAL_STATE.with(|state| state.attach(db.as_dyn_database(), || op(state))) -} - -/// Access the "attached" database. Returns `None` if no database is attached. -/// Databases are attached with `attach_database`. -pub fn with_attached_database(op: impl FnOnce(&dyn Database) -> R) -> Option { - LOCAL_STATE.with(|state| { - if let Some(db) = state.database.get() { - // SAFETY: We always attach the database in for the entire duration of a function, - // so it cannot become "unattached" while this function is running. - Some(op(unsafe { db.as_ref() })) - } else { - None - } - }) -} - /// State that is specific to a single execution thread. /// /// Internally, this type uses ref-cells. /// /// **Note also that all mutations to the database handle (and hence /// to the local-state) must be undone during unwinding.** -pub(crate) struct LocalState { - /// Pointer to the currently attached database. - database: Cell>>, - +pub struct LocalState { /// Vector of active queries. /// /// This is normally `Some`, but it is set to `None` @@ -67,56 +34,12 @@ pub(crate) struct LocalState { } impl LocalState { - const fn new() -> Self { + pub(crate) fn new() -> Self { LocalState { - database: Cell::new(None), query_stack: RefCell::new(Some(vec![])), } } - fn attach(&self, db: &dyn Database, op: impl FnOnce() -> R) -> R { - struct DbGuard<'s> { - state: Option<&'s LocalState>, - } - - impl<'s> DbGuard<'s> { - fn new(state: &'s LocalState, db: &dyn Database) -> Self { - if let Some(current_db) = state.database.get() { - let new_db = NonNull::from(db); - - // Already attached? Assert that the database has not changed. - // NOTE: It's important to use `addr_eq` here because `NonNull::eq` not only compares the address but also the type's metadata. - if !std::ptr::addr_eq(current_db.as_ptr(), new_db.as_ptr()) { - panic!( - "Cannot change database mid-query. current: {current_db:?}, new: {new_db:?}", - ); - } - - Self { state: None } - } else { - // Otherwise, set the database. - state.database.set(Some(NonNull::from(db))); - Self { state: Some(state) } - } - } - } - - impl Drop for DbGuard<'_> { - fn drop(&mut self) { - // Reset database to null if we did anything in `DbGuard::new`. - if let Some(state) = self.state { - state.database.set(None); - - // All stack frames should have been popped from the local stack. - assert!(state.query_stack.borrow().as_ref().unwrap().is_empty()); - } - } - } - - let _guard = DbGuard::new(self, db); - op() - } - #[inline] pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> { let mut query_stack = self.query_stack.borrow_mut(); diff --git a/src/storage.rs b/src/storage.rs index 77fc6210..d37ada2c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -5,12 +5,12 @@ use parking_lot::Mutex; use rustc_hash::FxHashMap; use crate::cycle::CycleRecoveryStrategy; -use crate::database::{DatabaseImpl, UserData}; +use crate::database::UserData; use crate::ingredient::{Ingredient, Jar}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::Runtime; -use crate::views::{Views, ViewsOf}; -use crate::{Database, Durability, Revision}; +use crate::views::Views; +use crate::{Database, DatabaseImpl, Durability, Revision}; pub fn views(db: &Db) -> &Views { db.zalsa().views() @@ -191,7 +191,7 @@ impl IngredientIndex { pub(crate) struct ZalsaImpl { user_data: U, - views_of: ViewsOf>, + views_of: Views, nonce: Nonce, @@ -227,7 +227,7 @@ impl Default for ZalsaImpl { impl ZalsaImpl { pub(crate) fn with(user_data: U) -> Self { Self { - views_of: Default::default(), + views_of: Views::new::>(), nonce: NONCE.nonce(), jar_map: Default::default(), ingredients_vec: Default::default(), diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 45136529..b8460c34 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -11,7 +11,7 @@ use crate::{ ingredient::{fmt_index, Ingredient, Jar}, ingredient_list::IngredientList, key::{DatabaseKeyIndex, DependencyIndex}, - local_state::{self, QueryOrigin}, + local_state::QueryOrigin, salsa_struct::SalsaStructInDb, storage::IngredientIndex, Database, Durability, Event, Id, Revision, @@ -290,86 +290,85 @@ where db: &'db dyn Database, fields: C::Fields<'db>, ) -> C::Struct<'db> { - local_state::attach(db, |local_state| { - let zalsa = db.zalsa(); - - let data_hash = crate::hash::hash(&C::id_fields(&fields)); - - let (query_key, current_deps, disambiguator) = - local_state.disambiguate(self.ingredient_index, Revision::start(), data_hash); - - let entity_key = KeyStruct { - query_key, - disambiguator, - data_hash, - }; - - let (id, new_id) = self.intern(entity_key); - local_state.add_output(self.database_key_index(id).into()); - - let current_revision = zalsa.current_revision(); - if new_id { - // This is a new tracked struct, so create an entry in the struct map. - - self.struct_map.insert( - current_revision, - Value { - id, - key: entity_key, - struct_ingredient_index: self.ingredient_index, - created_at: current_revision, - durability: current_deps.durability, - fields: unsafe { self.to_static(fields) }, - revisions: C::new_revisions(current_deps.changed_at), - }, - ) - } else { - // The struct already exists in the intern map. - // Note that we assume there is at most one executing copy of - // the current query at a time, which implies that the - // struct must exist in `self.struct_map` already - // (if the same query could execute twice in parallel, - // then it would potentially create the same struct twice in parallel, - // which means the interned key could exist but `struct_map` not yet have - // been updated). - - match self.struct_map.update(current_revision, id) { - Update::Current(r) => { - // All inputs up to this point were previously - // observed to be green and this struct was already - // verified. Therefore, the durability ought not to have - // changed (nor the field values, but the user could've - // done something stupid, so we can't *assert* this is true). - assert!(C::deref_struct(r).durability == current_deps.durability); - - r + let zalsa = db.zalsa(); + let zalsa_local = db.zalsa_local(); + + let data_hash = crate::hash::hash(&C::id_fields(&fields)); + + let (query_key, current_deps, disambiguator) = + zalsa_local.disambiguate(self.ingredient_index, Revision::start(), data_hash); + + let entity_key = KeyStruct { + query_key, + disambiguator, + data_hash, + }; + + let (id, new_id) = self.intern(entity_key); + zalsa_local.add_output(self.database_key_index(id).into()); + + let current_revision = zalsa.current_revision(); + if new_id { + // This is a new tracked struct, so create an entry in the struct map. + + self.struct_map.insert( + current_revision, + Value { + id, + key: entity_key, + struct_ingredient_index: self.ingredient_index, + created_at: current_revision, + durability: current_deps.durability, + fields: unsafe { self.to_static(fields) }, + revisions: C::new_revisions(current_deps.changed_at), + }, + ) + } else { + // The struct already exists in the intern map. + // Note that we assume there is at most one executing copy of + // the current query at a time, which implies that the + // struct must exist in `self.struct_map` already + // (if the same query could execute twice in parallel, + // then it would potentially create the same struct twice in parallel, + // which means the interned key could exist but `struct_map` not yet have + // been updated). + + match self.struct_map.update(current_revision, id) { + Update::Current(r) => { + // All inputs up to this point were previously + // observed to be green and this struct was already + // verified. Therefore, the durability ought not to have + // changed (nor the field values, but the user could've + // done something stupid, so we can't *assert* this is true). + assert!(C::deref_struct(r).durability == current_deps.durability); + + r + } + Update::Outdated(mut data_ref) => { + let data = &mut *data_ref; + + // SAFETY: We assert that the pointer to `data.revisions` + // is a pointer into the database referencing a value + // from a previous revision. As such, it continues to meet + // its validity invariant and any owned content also continues + // to meet its safety invariant. + unsafe { + C::update_fields( + current_revision, + &mut data.revisions, + self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)), + fields, + ); } - Update::Outdated(mut data_ref) => { - let data = &mut *data_ref; - - // SAFETY: We assert that the pointer to `data.revisions` - // is a pointer into the database referencing a value - // from a previous revision. As such, it continues to meet - // its validity invariant and any owned content also continues - // to meet its safety invariant. - unsafe { - C::update_fields( - current_revision, - &mut data.revisions, - self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)), - fields, - ); - } - if current_deps.durability < data.durability { - data.revisions = C::new_revisions(current_revision); - } - data.durability = current_deps.durability; - data.created_at = current_revision; - data_ref.freeze() + if current_deps.durability < data.durability { + data.revisions = C::new_revisions(current_revision); } + data.durability = current_deps.durability; + data.created_at = current_revision; + data_ref.freeze() } } - }) + } } /// Given the id of a tracked struct created in this revision, @@ -520,21 +519,20 @@ where db: &dyn crate::Database, field_index: usize, ) -> &'db C::Fields<'db> { - local_state::attach(db, |local_state| { - let field_ingredient_index = self.struct_ingredient_index.successor(field_index); - let changed_at = self.revisions[field_index]; - - local_state.report_tracked_read( - DependencyIndex { - ingredient_index: field_ingredient_index, - key_index: Some(self.id.as_id()), - }, - self.durability, - changed_at, - ); + let zalsa_local = db.zalsa_local(); + let field_ingredient_index = self.struct_ingredient_index.successor(field_index); + let changed_at = self.revisions[field_index]; + + zalsa_local.report_tracked_read( + DependencyIndex { + ingredient_index: field_ingredient_index, + key_index: Some(self.id.as_id()), + }, + self.durability, + changed_at, + ); - unsafe { self.to_self_ref(&self.fields) } - }) + unsafe { self.to_self_ref(&self.fields) } } unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> { diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 7a6b7b42..2061d1f0 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,6 +1,5 @@ use crate::{ - id::AsId, ingredient::Ingredient, key::DependencyIndex, local_state, storage::IngredientIndex, - Database, Id, + id::AsId, ingredient::Ingredient, key::DependencyIndex, storage::IngredientIndex, Database, Id, }; use super::{struct_map::StructMapView, Configuration}; @@ -47,23 +46,22 @@ where /// Note that this function returns the entire tuple of value fields. /// The caller is responible for selecting the appropriate element. pub fn field<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> { - local_state::attach(db, |local_state| { - let current_revision = db.zalsa().current_revision(); - let data = self.struct_map.get(current_revision, id); - let data = C::deref_struct(data); - let changed_at = data.revisions[self.field_index]; - - local_state.report_tracked_read( - DependencyIndex { - ingredient_index: self.ingredient_index, - key_index: Some(id.as_id()), - }, - data.durability, - changed_at, - ); - - unsafe { self.to_self_ref(&data.fields) } - }) + let zalsa_local = db.zalsa_local(); + let current_revision = db.zalsa().current_revision(); + let data = self.struct_map.get(current_revision, id); + let data = C::deref_struct(data); + let changed_at = data.revisions[self.field_index]; + + zalsa_local.report_tracked_read( + DependencyIndex { + ingredient_index: self.ingredient_index, + key_index: Some(id.as_id()), + }, + data.durability, + changed_at, + ); + + unsafe { self.to_self_ref(&data.fields) } } } diff --git a/src/views.rs b/src/views.rs index 75369e75..19798737 100644 --- a/src/views.rs +++ b/src/views.rs @@ -1,7 +1,5 @@ use std::{ any::{Any, TypeId}, - marker::PhantomData, - ops::Deref, sync::Arc, }; @@ -9,11 +7,6 @@ use orx_concurrent_vec::ConcurrentVec; use crate::Database; -pub struct ViewsOf { - upcasts: Views, - phantom: PhantomData, -} - #[derive(Clone)] pub struct Views { source_type_id: TypeId, @@ -29,25 +22,8 @@ struct ViewCaster { #[allow(dead_code)] enum Dummy {} -impl Default for ViewsOf { - fn default() -> Self { - Self { - upcasts: Views::new::(), - phantom: Default::default(), - } - } -} - -impl Deref for ViewsOf { - type Target = Views; - - fn deref(&self) -> &Self::Target { - &self.upcasts - } -} - impl Views { - fn new() -> Self { + pub(crate) fn new() -> Self { let source_type_id = TypeId::of::(); Self { source_type_id, @@ -127,12 +103,3 @@ fn data_ptr(t: &T) -> &() { let u: *const () = t as *const (); unsafe { &*u } } - -impl Clone for ViewsOf { - fn clone(&self) -> Self { - Self { - upcasts: self.upcasts.clone(), - phantom: self.phantom, - } - } -} diff --git a/tests/accumulate.rs b/tests/accumulate.rs index e20cc05f..ea16666d 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -55,7 +55,7 @@ fn push_b_logs(db: &dyn LogDatabase, input: MyInput) { #[test] fn accumulate_once() { - let mut db = salsa::DatabaseImpl::with(Logger::default()); + let db = salsa::DatabaseImpl::with(Logger::default()); // Just call accumulate on a base input to see what happens. let input = MyInput::new(&db, 2, 3); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index cb741f37..752d9e4c 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -25,7 +25,7 @@ pub trait LogDatabase: HasLogger + salsa::Database { /// Asserts what the (formatted) logs should look like, /// clearing the logged events. This takes `&mut self` because /// it is meant to be run from outside any tracked functions. - fn assert_logs(&mut self, expected: expect_test::Expect) { + fn assert_logs(&self, expected: expect_test::Expect) { let logs = std::mem::take(&mut *self.logger().logs.lock().unwrap()); expected.assert_eq(&format!("{:#?}", logs)); } @@ -33,7 +33,7 @@ pub trait LogDatabase: HasLogger + salsa::Database { /// Asserts the length of the logs, /// clearing the logged events. This takes `&mut self` because /// it is meant to be run from outside any tracked functions. - fn assert_logs_len(&mut self, expected: usize) { + fn assert_logs_len(&self, expected: usize) { let logs = std::mem::take(&mut *self.logger().logs.lock().unwrap()); assert_eq!(logs.len(), expected); } From f8b1620ca7379b0662d91620f637b7e2212d0baf Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 10:01:35 +0000 Subject: [PATCH 11/29] pacify the merciless cargo fmt --- src/storage.rs | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index d37ada2c..689b4eab 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -41,10 +41,7 @@ pub trait Zalsa { fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient; /// Gets an `&mut`-ref to an ingredient by index. - fn lookup_ingredient_mut( - &mut self, - index: IngredientIndex, - ) -> &mut dyn Ingredient; + fn lookup_ingredient_mut(&mut self, index: IngredientIndex) -> &mut dyn Ingredient; fn runtimex(&self) -> &Runtime; @@ -88,11 +85,11 @@ impl Zalsa for ZalsaImpl { let ingredients = jar.create_ingredients(index); for ingredient in ingredients { let expected_index = ingredient.ingredient_index(); - + if ingredient.requires_reset_for_new_revision() { self.ingredients_requiring_reset.push(expected_index); } - + let actual_index = self .ingredients_vec .push(ingredient); @@ -104,7 +101,7 @@ impl Zalsa for ZalsaImpl { expected_index, actual_index, ); - + } index }) @@ -115,33 +112,30 @@ impl Zalsa for ZalsaImpl { &**self.ingredients_vec.get(index.as_usize()).unwrap() } - fn lookup_ingredient_mut( - &mut self, - index: IngredientIndex, - ) -> &mut dyn Ingredient { + fn lookup_ingredient_mut(&mut self, index: IngredientIndex) -> &mut dyn Ingredient { &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap() } fn current_revision(&self) -> Revision { self.runtime.current_revision() } - + fn load_cancellation_flag(&self) -> bool { self.runtime.load_cancellation_flag() } - + fn report_tracked_write(&mut self, durability: Durability) { self.runtime.report_tracked_write(durability) } - + fn runtimex(&self) -> &Runtime { &self.runtime } - + fn last_changed_revision(&self, durability: Durability) -> Revision { self.runtime.last_changed_revision(durability) } - + fn set_cancellation_flag(&self) { self.runtime.set_cancellation_flag() } @@ -190,7 +184,7 @@ impl IngredientIndex { /// It is shared between the main database and any active snapshots. pub(crate) struct ZalsaImpl { user_data: U, - + views_of: Views, nonce: Nonce, From 85628247e5a80d1063704317ef4cbc6af6e824dc Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 12:42:34 +0000 Subject: [PATCH 12/29] pacify the merciless clippy --- src/attach.rs | 12 +++++------- src/database.rs | 5 +++-- src/tracked_struct/struct_map.rs | 16 ++++++++-------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/attach.rs b/src/attach.rs index 3dcf9d11..9f276596 100644 --- a/src/attach.rs +++ b/src/attach.rs @@ -74,13 +74,11 @@ impl Attached { /// Access the "attached" database. Returns `None` if no database is attached. /// Databases are attached with `attach_database`. fn with(&self, op: impl FnOnce(&dyn Database) -> R) -> Option { - if let Some(db) = self.database.get() { - // SAFETY: We always attach the database in for the entire duration of a function, - // so it cannot become "unattached" while this function is running. - Some(op(unsafe { db.as_ref() })) - } else { - None - } + let db = self.database.get()?; + + // SAFETY: We always attach the database in for the entire duration of a function, + // so it cannot become "unattached" while this function is running. + Some(op(unsafe { db.as_ref() })) } } diff --git a/src/database.rs b/src/database.rs index 3643c6dd..3680d709 100644 --- a/src/database.rs +++ b/src/database.rs @@ -12,10 +12,11 @@ use crate::{ /// The trait implemented by all Salsa databases. /// You can create your own subtraits of this trait using the `#[salsa::db]` procedural macro. /// -/// # Safety conditions +/// # Safety /// /// This trait can only safely be implemented by Salsa's [`DatabaseImpl`][] type. -/// FIXME: Document better the unsafety conditions we guarantee. +/// +/// FIXME: Document better the unsafety conditions we require. #[salsa_macros::db] pub unsafe trait Database: Send + AsDynDatabase + Any { /// This function is invoked by the salsa runtime at various points during execution. diff --git a/src/tracked_struct/struct_map.rs b/src/tracked_struct/struct_map.rs index b8ea1578..2502f485 100644 --- a/src/tracked_struct/struct_map.rs +++ b/src/tracked_struct/struct_map.rs @@ -80,7 +80,7 @@ where /// /// * If value with same `value.id` is already present in the map. /// * If value not created in current revision. - pub fn insert<'db>(&'db self, current_revision: Revision, value: Value) -> C::Struct<'db> { + pub fn insert(&self, current_revision: Revision, value: Value) -> C::Struct<'_> { assert_eq!(value.created_at, current_revision); let id = value.id; @@ -99,7 +99,7 @@ where unsafe { C::struct_from_raw(pointer) } } - pub fn validate<'db>(&'db self, current_revision: Revision, id: Id) { + pub fn validate(&self, current_revision: Revision, id: Id) { let mut data = self.map.get_mut(&id).unwrap(); // UNSAFE: We never permit `&`-access in the current revision until data.created_at @@ -118,7 +118,7 @@ where /// /// * If the value is not present in the map. /// * If the value is already updated in this revision. - pub fn update<'db>(&'db self, current_revision: Revision, id: Id) -> Update<'db, C> { + pub fn update(&self, current_revision: Revision, id: Id) -> Update<'_, C> { let mut data = self.map.get_mut(&id).unwrap(); // UNSAFE: We never permit `&`-access in the current revision until data.created_at @@ -163,7 +163,7 @@ where /// /// * If the value is not present in the map. /// * If the value has not been updated in this revision. - pub fn get<'db>(&'db self, current_revision: Revision, id: Id) -> C::Struct<'db> { + pub fn get(&self, current_revision: Revision, id: Id) -> C::Struct<'_> { Self::get_from_map(&self.map, current_revision, id) } @@ -173,11 +173,11 @@ where /// /// * If the value is not present in the map. /// * If the value has not been updated in this revision. - fn get_from_map<'db>( - map: &'db FxDashMap>>, + fn get_from_map( + map: &FxDashMap>>, current_revision: Revision, id: Id, - ) -> C::Struct<'db> { + ) -> C::Struct<'_> { let data = map.get(&id).unwrap(); // UNSAFE: We permit `&`-access in the current revision once data.created_at @@ -230,7 +230,7 @@ where /// /// * If the value is not present in the map. /// * If the value has not been updated in this revision. - pub fn get<'db>(&'db self, current_revision: Revision, id: Id) -> C::Struct<'db> { + pub fn get(&self, current_revision: Revision, id: Id) -> C::Struct<'_> { StructMap::get_from_map(&self.map, current_revision, id) } } From 34e109d3905b3a8df52030c34c73ce22cd3c29cf Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 13:01:09 +0000 Subject: [PATCH 13/29] remove type parameter from `ZalsaImpl` --- src/database.rs | 13 +++++++++---- src/storage.rs | 26 +++++++++----------------- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/src/database.rs b/src/database.rs index 3680d709..e672c64e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,4 +1,4 @@ -use std::{any::Any, panic::RefUnwindSafe, sync::Arc}; +use std::{any::Any, marker::PhantomData, panic::RefUnwindSafe, sync::Arc}; use parking_lot::{Condvar, Mutex}; @@ -113,7 +113,7 @@ impl dyn Database { /// Takes an optional type parameter `U` that allows you to thread your own data. pub struct DatabaseImpl { /// Reference to the database. This is always `Some` except during destruction. - zalsa_impl: Option>>, + zalsa_impl: Option>, /// Coordination data for cancellation of other handles when `zalsa_mut` is called. /// This could be stored in ZalsaImpl but it makes things marginally cleaner to keep it separate. @@ -121,6 +121,9 @@ pub struct DatabaseImpl { /// Per-thread state zalsa_local: local_state::LocalState, + + /// The `U` is stored as a `dyn Any` in `zalsa_impl` + phantom: PhantomData, } impl Default for DatabaseImpl { @@ -150,13 +153,14 @@ impl DatabaseImpl { cvar: Default::default(), }), zalsa_local: LocalState::new(), + phantom: PhantomData::, } } /// Access the `Arc`. This should always be /// possible as `zalsa_impl` only becomes /// `None` once we are in the `Drop` impl. - fn zalsa_impl(&self) -> &Arc> { + fn zalsa_impl(&self) -> &Arc { self.zalsa_impl.as_ref().unwrap() } @@ -188,7 +192,7 @@ impl std::ops::Deref for DatabaseImpl { type Target = U; fn deref(&self) -> &U { - self.zalsa_impl().user_data() + self.zalsa_impl().user_data().downcast_ref::().unwrap() } } @@ -228,6 +232,7 @@ impl Clone for DatabaseImpl { zalsa_impl: self.zalsa_impl.clone(), coordinate: Arc::clone(&self.coordinate), zalsa_local: LocalState::new(), + phantom: PhantomData::, } } } diff --git a/src/storage.rs b/src/storage.rs index 689b4eab..1726e44d 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,4 +1,4 @@ -use std::any::TypeId; +use std::any::{Any, TypeId}; use orx_concurrent_vec::ConcurrentVec; use parking_lot::Mutex; @@ -61,7 +61,7 @@ pub trait Zalsa { fn report_tracked_write(&mut self, durability: Durability); } -impl Zalsa for ZalsaImpl { +impl Zalsa for ZalsaImpl { fn views(&self) -> &Views { &self.views_of } @@ -182,8 +182,8 @@ impl IngredientIndex { /// The "storage" struct stores all the data for the jars. /// It is shared between the main database and any active snapshots. -pub(crate) struct ZalsaImpl { - user_data: U, +pub(crate) struct ZalsaImpl { + user_data: Box, views_of: Views, @@ -210,16 +210,8 @@ pub(crate) struct ZalsaImpl { runtime: Runtime, } -// ANCHOR: default -impl Default for ZalsaImpl { - fn default() -> Self { - Self::with(Default::default()) - } -} -// ANCHOR_END: default - -impl ZalsaImpl { - pub(crate) fn with(user_data: U) -> Self { +impl ZalsaImpl { + pub(crate) fn with(user_data: U) -> Self { Self { views_of: Views::new::>(), nonce: NONCE.nonce(), @@ -227,12 +219,12 @@ impl ZalsaImpl { ingredients_vec: Default::default(), ingredients_requiring_reset: Default::default(), runtime: Runtime::default(), - user_data, + user_data: Box::new(user_data), } } - pub(crate) fn user_data(&self) -> &U { - &self.user_data + pub(crate) fn user_data(&self) -> &(dyn Any + Send + Sync) { + &*self.user_data } /// Triggers a new revision. Invoked automatically when you call `zalsa_mut` From 703e12def8a7ffb2fb8d4d19814bce99539fb11e Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 13:06:30 +0000 Subject: [PATCH 14/29] remove the Zalsa trait and make it a struct --- src/database.rs | 20 +-- src/function/maybe_changed_after.rs | 2 +- src/function/memo.rs | 2 +- src/storage.rs | 212 +++++++++++----------------- 4 files changed, 95 insertions(+), 141 deletions(-) diff --git a/src/database.rs b/src/database.rs index e672c64e..f92e0992 100644 --- a/src/database.rs +++ b/src/database.rs @@ -5,7 +5,7 @@ use parking_lot::{Condvar, Mutex}; use crate::{ self as salsa, local_state::{self, LocalState}, - storage::{Zalsa, ZalsaImpl}, + storage::Zalsa, Durability, Event, EventKind, Revision, }; @@ -61,14 +61,14 @@ pub unsafe trait Database: Send + AsDynDatabase + Any { /// Plumbing method: Access the internal salsa methods. #[doc(hidden)] - fn zalsa(&self) -> &dyn Zalsa; + fn zalsa(&self) -> &Zalsa; /// Plumbing method: Access the internal salsa methods for mutating the database. /// /// **WARNING:** Triggers a new revision, canceling other database handles. /// This can lead to deadlock! #[doc(hidden)] - fn zalsa_mut(&mut self) -> &mut dyn Zalsa; + fn zalsa_mut(&mut self) -> &mut Zalsa; /// Access the thread-local state associated with this database #[doc(hidden)] @@ -113,10 +113,10 @@ impl dyn Database { /// Takes an optional type parameter `U` that allows you to thread your own data. pub struct DatabaseImpl { /// Reference to the database. This is always `Some` except during destruction. - zalsa_impl: Option>, + zalsa_impl: Option>, /// Coordination data for cancellation of other handles when `zalsa_mut` is called. - /// This could be stored in ZalsaImpl but it makes things marginally cleaner to keep it separate. + /// This could be stored in Zalsa but it makes things marginally cleaner to keep it separate. coordinate: Arc, /// Per-thread state @@ -147,7 +147,7 @@ impl DatabaseImpl { /// You can also use the [`Default`][] trait if your userdata implements it. pub fn with(u: U) -> Self { Self { - zalsa_impl: Some(Arc::new(ZalsaImpl::with(u))), + zalsa_impl: Some(Arc::new(Zalsa::with(u))), coordinate: Arc::new(Coordinate { clones: Mutex::new(1), cvar: Default::default(), @@ -157,10 +157,10 @@ impl DatabaseImpl { } } - /// Access the `Arc`. This should always be + /// Access the `Arc`. This should always be /// possible as `zalsa_impl` only becomes /// `None` once we are in the `Drop` impl. - fn zalsa_impl(&self) -> &Arc { + fn zalsa_impl(&self) -> &Arc { self.zalsa_impl.as_ref().unwrap() } @@ -200,11 +200,11 @@ impl RefUnwindSafe for DatabaseImpl {} #[salsa_macros::db] unsafe impl Database for DatabaseImpl { - fn zalsa(&self) -> &dyn Zalsa { + fn zalsa(&self) -> &Zalsa { &**self.zalsa_impl() } - fn zalsa_mut(&mut self) -> &mut dyn Zalsa { + fn zalsa_mut(&mut self) -> &mut Zalsa { self.cancel_others(); // The ref count on the `Arc` should now be 1 diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 0d3fc3d4..0460e6db 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -98,7 +98,7 @@ where pub(super) fn shallow_verify_memo( &self, db: &C::DbView, - zalsa: &dyn Zalsa, + zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo>, ) -> bool { diff --git a/src/function/memo.rs b/src/function/memo.rs index 4413d71a..baeed166 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -129,7 +129,7 @@ impl Memo { } } /// True if this memo is known not to have changed based on its durability. - pub(super) fn check_durability(&self, zalsa: &dyn Zalsa) -> bool { + pub(super) fn check_durability(&self, zalsa: &Zalsa) -> bool { let last_changed = zalsa.last_changed_revision(self.revisions.durability); let verified_at = self.verified_at.load(); tracing::debug!( diff --git a/src/storage.rs b/src/storage.rs index 1726e44d..1bc40607 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -16,131 +16,6 @@ pub fn views(db: &Db) -> &Views { db.zalsa().views() } -/// The "plumbing interface" to the Salsa database. -/// -/// **NOT SEMVER STABLE.** -pub trait Zalsa { - /// Returns a reference to the underlying. - fn views(&self) -> &Views; - - /// Returns the nonce for the underyling storage. - /// - /// # Safety - /// - /// This nonce is guaranteed to be unique for the database and never to be reused. - fn nonce(&self) -> Nonce; - - /// Lookup the index assigned to the given jar (if any). This lookup is based purely on the jar's type. - fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option; - - /// Adds a jar to the database, returning the index of the first ingredient. - /// If a jar of this type is already present, returns the existing index. - fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex; - - /// Gets an `&`-ref to an ingredient by index - fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient; - - /// Gets an `&mut`-ref to an ingredient by index. - fn lookup_ingredient_mut(&mut self, index: IngredientIndex) -> &mut dyn Ingredient; - - fn runtimex(&self) -> &Runtime; - - /// Return the current revision - fn current_revision(&self) -> Revision; - - /// Return the time when an input of durability `durability` last changed - fn last_changed_revision(&self, durability: Durability) -> Revision; - - /// True if any threads have signalled for cancellation - fn load_cancellation_flag(&self) -> bool; - - /// Signal for cancellation, indicating current thread is trying to get unique access. - fn set_cancellation_flag(&self); - - /// Reports a (synthetic) tracked write to "some input of the given durability". - fn report_tracked_write(&mut self, durability: Durability); -} - -impl Zalsa for ZalsaImpl { - fn views(&self) -> &Views { - &self.views_of - } - - fn nonce(&self) -> Nonce { - self.nonce - } - - fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option { - self.jar_map.lock().get(&jar.type_id()).copied() - } - - fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex { - { - let jar_type_id = jar.type_id(); - let mut jar_map = self.jar_map.lock(); - *jar_map - .entry(jar_type_id) - .or_insert_with(|| { - let index = IngredientIndex::from(self.ingredients_vec.len()); - let ingredients = jar.create_ingredients(index); - for ingredient in ingredients { - let expected_index = ingredient.ingredient_index(); - - if ingredient.requires_reset_for_new_revision() { - self.ingredients_requiring_reset.push(expected_index); - } - - let actual_index = self - .ingredients_vec - .push(ingredient); - assert_eq!( - expected_index.as_usize(), - actual_index, - "ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`", - self.ingredients_vec.get(actual_index).unwrap(), - expected_index, - actual_index, - ); - - } - index - }) - } - } - - fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient { - &**self.ingredients_vec.get(index.as_usize()).unwrap() - } - - fn lookup_ingredient_mut(&mut self, index: IngredientIndex) -> &mut dyn Ingredient { - &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap() - } - - fn current_revision(&self) -> Revision { - self.runtime.current_revision() - } - - fn load_cancellation_flag(&self) -> bool { - self.runtime.load_cancellation_flag() - } - - fn report_tracked_write(&mut self, durability: Durability) { - self.runtime.report_tracked_write(durability) - } - - fn runtimex(&self) -> &Runtime { - &self.runtime - } - - fn last_changed_revision(&self, durability: Durability) -> Revision { - self.runtime.last_changed_revision(durability) - } - - fn set_cancellation_flag(&self) { - self.runtime.set_cancellation_flag() - } -} - /// Nonce type representing the underlying database storage. #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct StorageNonce; @@ -180,9 +55,10 @@ impl IngredientIndex { } } -/// The "storage" struct stores all the data for the jars. -/// It is shared between the main database and any active snapshots. -pub(crate) struct ZalsaImpl { +/// The "plumbing interface" to the Salsa database. Stores all the ingredients and other data. +/// +/// **NOT SEMVER STABLE.** +pub struct Zalsa { user_data: Box, views_of: Views, @@ -210,7 +86,7 @@ pub(crate) struct ZalsaImpl { runtime: Runtime, } -impl ZalsaImpl { +impl Zalsa { pub(crate) fn with(user_data: U) -> Self { Self { views_of: Views::new::>(), @@ -223,6 +99,84 @@ impl ZalsaImpl { } } + pub(crate) fn views(&self) -> &Views { + &self.views_of + } + + pub(crate) fn nonce(&self) -> Nonce { + self.nonce + } + + /// **NOT SEMVER STABLE** + pub fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex { + { + let jar_type_id = jar.type_id(); + let mut jar_map = self.jar_map.lock(); + *jar_map + .entry(jar_type_id) + .or_insert_with(|| { + let index = IngredientIndex::from(self.ingredients_vec.len()); + let ingredients = jar.create_ingredients(index); + for ingredient in ingredients { + let expected_index = ingredient.ingredient_index(); + + if ingredient.requires_reset_for_new_revision() { + self.ingredients_requiring_reset.push(expected_index); + } + + let actual_index = self + .ingredients_vec + .push(ingredient); + assert_eq!( + expected_index.as_usize(), + actual_index, + "ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`", + self.ingredients_vec.get(actual_index).unwrap(), + expected_index, + actual_index, + ); + + } + index + }) + } + } + + pub(crate) fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient { + &**self.ingredients_vec.get(index.as_usize()).unwrap() + } + + /// **NOT SEMVER STABLE** + pub fn lookup_ingredient_mut(&mut self, index: IngredientIndex) -> &mut dyn Ingredient { + &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap() + } + + /// **NOT SEMVER STABLE** + pub fn current_revision(&self) -> Revision { + self.runtime.current_revision() + } + + pub(crate) fn load_cancellation_flag(&self) -> bool { + self.runtime.load_cancellation_flag() + } + + pub(crate) fn report_tracked_write(&mut self, durability: Durability) { + self.runtime.report_tracked_write(durability) + } + + pub(crate) fn runtimex(&self) -> &Runtime { + &self.runtime + } + + /// **NOT SEMVER STABLE** + pub fn last_changed_revision(&self, durability: Durability) -> Revision { + self.runtime.last_changed_revision(durability) + } + + pub(crate) fn set_cancellation_flag(&self) { + self.runtime.set_cancellation_flag() + } + pub(crate) fn user_data(&self) -> &(dyn Any + Send + Sync) { &*self.user_data } From d141cd850044e1554876c4d43b62c098166b5a44 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 13:10:16 +0000 Subject: [PATCH 15/29] encapsulate Runtime within Zalsa The aim is to eventually eliminate Runtime. --- src/function/sync.rs | 14 +++++++------- src/storage.rs | 33 +++++++++++++++++++++++++++------ 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/function/sync.rs b/src/function/sync.rs index 0f1d5178..6d5e2b8b 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -4,8 +4,8 @@ use std::{ }; use crate::{ - hash::FxDashMap, key::DatabaseKeyIndex, local_state::LocalState, runtime::WaitResult, Database, - Id, Runtime, + hash::FxDashMap, key::DatabaseKeyIndex, local_state::LocalState, runtime::WaitResult, + storage::Zalsa, Database, Id, }; #[derive(Default)] @@ -28,7 +28,7 @@ impl SyncMap { local_state: &LocalState, database_key_index: DatabaseKeyIndex, ) -> Option> { - let runtime = db.zalsa().runtimex(); + let zalsa = db.zalsa(); let thread_id = std::thread::current().id(); match self.sync_map.entry(database_key_index.key_index) { dashmap::mapref::entry::Entry::Vacant(entry) => { @@ -38,7 +38,7 @@ impl SyncMap { }); Some(ClaimGuard { database_key: database_key_index, - runtime, + zalsa, sync_map: &self.sync_map, }) } @@ -51,7 +51,7 @@ impl SyncMap { // not to gate future atomic reads. entry.get().anyone_waiting.store(true, Ordering::Relaxed); let other_id = entry.get().id; - runtime.block_on_or_unwind(db, local_state, database_key_index, other_id, entry); + zalsa.block_on_or_unwind(db, local_state, database_key_index, other_id, entry); None } } @@ -63,7 +63,7 @@ impl SyncMap { #[must_use] pub(super) struct ClaimGuard<'me> { database_key: DatabaseKeyIndex, - runtime: &'me Runtime, + zalsa: &'me Zalsa, sync_map: &'me FxDashMap, } @@ -75,7 +75,7 @@ impl<'me> ClaimGuard<'me> { // NB: `Ordering::Relaxed` is sufficient here, // see `store` above for explanation. if anyone_waiting.load(Ordering::Relaxed) { - self.runtime + self.zalsa .unblock_queries_blocked_on(self.database_key, wait_result) } } diff --git a/src/storage.rs b/src/storage.rs index 1bc40607..292fbd1d 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,4 +1,5 @@ use std::any::{Any, TypeId}; +use std::thread::ThreadId; use orx_concurrent_vec::ConcurrentVec; use parking_lot::Mutex; @@ -7,10 +8,11 @@ use rustc_hash::FxHashMap; use crate::cycle::CycleRecoveryStrategy; use crate::database::UserData; use crate::ingredient::{Ingredient, Jar}; +use crate::local_state::LocalState; use crate::nonce::{Nonce, NonceGenerator}; -use crate::runtime::Runtime; +use crate::runtime::{Runtime, WaitResult}; use crate::views::Views; -use crate::{Database, DatabaseImpl, Durability, Revision}; +use crate::{Database, DatabaseImpl, DatabaseKeyIndex, Durability, Revision}; pub fn views(db: &Db) -> &Views { db.zalsa().views() @@ -164,10 +166,6 @@ impl Zalsa { self.runtime.report_tracked_write(durability) } - pub(crate) fn runtimex(&self) -> &Runtime { - &self.runtime - } - /// **NOT SEMVER STABLE** pub fn last_changed_revision(&self, durability: Durability) -> Revision { self.runtime.last_changed_revision(durability) @@ -195,6 +193,29 @@ impl Zalsa { new_revision } + + /// See [`Runtime::block_on_or_unwind`][] + pub(crate) fn block_on_or_unwind( + &self, + db: &dyn Database, + local_state: &LocalState, + database_key: DatabaseKeyIndex, + other_id: ThreadId, + query_mutex_guard: QueryMutexGuard, + ) { + self.runtime + .block_on_or_unwind(db, local_state, database_key, other_id, query_mutex_guard) + } + + /// See [`Runtime::unblock_queries_blocked_on`][] + pub(crate) fn unblock_queries_blocked_on( + &self, + database_key: DatabaseKeyIndex, + wait_result: WaitResult, + ) { + self.runtime + .unblock_queries_blocked_on(database_key, wait_result) + } } /// Caches a pointer to an ingredient in a database. From 905437754671b4b8b0b57f1b866f020e701eab1a Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 13:11:57 +0000 Subject: [PATCH 16/29] rename `storage` mod to `zalsa` --- src/accumulator.rs | 2 +- src/database.rs | 2 +- src/function.rs | 2 +- src/function/maybe_changed_after.rs | 2 +- src/function/memo.rs | 2 +- src/function/sync.rs | 2 +- src/ingredient.rs | 2 +- src/ingredient_list.rs | 2 +- src/input.rs | 4 ++-- src/input/input_field.rs | 2 +- src/interned.rs | 2 +- src/key.rs | 2 +- src/lib.rs | 10 +++++----- src/local_state.rs | 2 +- src/salsa_struct.rs | 2 +- src/tracked_struct.rs | 4 ++-- src/tracked_struct/tracked_field.rs | 2 +- src/{storage.rs => zalsa.rs} | 0 18 files changed, 23 insertions(+), 23 deletions(-) rename src/{storage.rs => zalsa.rs} (100%) diff --git a/src/accumulator.rs b/src/accumulator.rs index 47133790..ea989658 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -11,7 +11,7 @@ use crate::{ ingredient::{fmt_index, Ingredient, Jar}, key::DependencyIndex, local_state::{LocalState, QueryOrigin}, - storage::IngredientIndex, + zalsa::IngredientIndex, Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, }; diff --git a/src/database.rs b/src/database.rs index f92e0992..c32e1063 100644 --- a/src/database.rs +++ b/src/database.rs @@ -5,7 +5,7 @@ use parking_lot::{Condvar, Mutex}; use crate::{ self as salsa, local_state::{self, LocalState}, - storage::Zalsa, + zalsa::Zalsa, Durability, Event, EventKind, Revision, }; diff --git a/src/function.rs b/src/function.rs index 72313cab..176d43e9 100644 --- a/src/function.rs +++ b/src/function.rs @@ -4,7 +4,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ cycle::CycleRecoveryStrategy, ingredient::fmt_index, key::DatabaseKeyIndex, - local_state::QueryOrigin, salsa_struct::SalsaStructInDb, storage::IngredientIndex, + local_state::QueryOrigin, salsa_struct::SalsaStructInDb, zalsa::IngredientIndex, AsDynDatabase as _, Cycle, Database, Event, EventKind, Id, Revision, }; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 0460e6db..cb6ec20f 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -4,7 +4,7 @@ use crate::{ key::DatabaseKeyIndex, local_state::{ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, runtime::StampedValue, - storage::Zalsa, + zalsa::Zalsa, AsDynDatabase as _, Database, Id, Revision, }; diff --git a/src/function/memo.rs b/src/function/memo.rs index baeed166..e84782ca 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -4,7 +4,7 @@ use arc_swap::{ArcSwap, Guard}; use crossbeam::atomic::AtomicCell; use crate::{ - hash::FxDashMap, key::DatabaseKeyIndex, local_state::QueryRevisions, storage::Zalsa, Event, + hash::FxDashMap, key::DatabaseKeyIndex, local_state::QueryRevisions, zalsa::Zalsa, Event, EventKind, Id, Revision, }; diff --git a/src/function/sync.rs b/src/function/sync.rs index 6d5e2b8b..72940083 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -5,7 +5,7 @@ use std::{ use crate::{ hash::FxDashMap, key::DatabaseKeyIndex, local_state::LocalState, runtime::WaitResult, - storage::Zalsa, Database, Id, + zalsa::Zalsa, Database, Id, }; #[derive(Default)] diff --git a/src/ingredient.rs b/src/ingredient.rs index 09056c1b..3f053604 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -4,7 +4,7 @@ use std::{ }; use crate::{ - cycle::CycleRecoveryStrategy, local_state::QueryOrigin, storage::IngredientIndex, Database, + cycle::CycleRecoveryStrategy, local_state::QueryOrigin, zalsa::IngredientIndex, Database, DatabaseKeyIndex, Id, }; diff --git a/src/ingredient_list.rs b/src/ingredient_list.rs index 3db9a5f5..92b54368 100644 --- a/src/ingredient_list.rs +++ b/src/ingredient_list.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use arc_swap::{ArcSwapOption, AsRaw}; -use crate::storage::IngredientIndex; +use crate::zalsa::IngredientIndex; /// A list of ingredients that can be added to in parallel. pub(crate) struct IngredientList { diff --git a/src/input.rs b/src/input.rs index 62fe321c..10c08c62 100644 --- a/src/input.rs +++ b/src/input.rs @@ -19,7 +19,7 @@ use crate::{ key::{DatabaseKeyIndex, DependencyIndex}, local_state::QueryOrigin, plumbing::{Jar, Stamp}, - storage::IngredientIndex, + zalsa::IngredientIndex, Database, Durability, Id, Revision, }; @@ -53,7 +53,7 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( &self, - struct_index: crate::storage::IngredientIndex, + struct_index: crate::zalsa::IngredientIndex, ) -> Vec> { let struct_ingredient: IngredientImpl = IngredientImpl::new(struct_index); let struct_map = struct_ingredient.struct_map.clone(); diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 9730a902..99f6f65a 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -2,7 +2,7 @@ use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::{fmt_index, Ingredient}; use crate::input::Configuration; use crate::local_state::QueryOrigin; -use crate::storage::IngredientIndex; +use crate::zalsa::IngredientIndex; use crate::{Database, DatabaseKeyIndex, Id, Revision}; use std::fmt; diff --git a/src/interned.rs b/src/interned.rs index 55d9fc3b..e6e76d88 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -11,7 +11,7 @@ use crate::ingredient::fmt_index; use crate::key::DependencyIndex; use crate::local_state::QueryOrigin; use crate::plumbing::Jar; -use crate::storage::IngredientIndex; +use crate::zalsa::IngredientIndex; use crate::{Database, DatabaseKeyIndex, Id}; use super::hash::FxDashMap; diff --git a/src/key.rs b/src/key.rs index 4cb5dd4c..df49d047 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,4 +1,4 @@ -use crate::{cycle::CycleRecoveryStrategy, storage::IngredientIndex, Database, Id}; +use crate::{cycle::CycleRecoveryStrategy, zalsa::IngredientIndex, Database, Id}; /// An integer that uniquely identifies a particular query instance within the /// database. Used to track dependencies between queries. Fully ordered and diff --git a/src/lib.rs b/src/lib.rs index 42e5ad00..aa018b45 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,10 +21,10 @@ mod nonce; mod revision; mod runtime; mod salsa_struct; -mod storage; mod tracked_struct; mod update; mod views; +mod zalsa; pub use self::accumulator::Accumulator; pub use self::cancelled::Cancelled; @@ -85,15 +85,15 @@ pub mod plumbing { pub use crate::runtime::Stamp; pub use crate::runtime::StampedValue; pub use crate::salsa_struct::SalsaStructInDb; - pub use crate::storage::views; - pub use crate::storage::IngredientCache; - pub use crate::storage::IngredientIndex; - pub use crate::storage::Zalsa; pub use crate::tracked_struct::TrackedStructInDb; pub use crate::update::always_update; pub use crate::update::helper::Dispatch as UpdateDispatch; pub use crate::update::helper::Fallback as UpdateFallback; pub use crate::update::Update; + pub use crate::zalsa::views; + pub use crate::zalsa::IngredientCache; + pub use crate::zalsa::IngredientIndex; + pub use crate::zalsa::Zalsa; pub use salsa_macro_rules::macro_if; pub use salsa_macro_rules::maybe_backdate; diff --git a/src/local_state.rs b/src/local_state.rs index 606415a2..3cb6652d 100644 --- a/src/local_state.rs +++ b/src/local_state.rs @@ -5,7 +5,7 @@ use crate::durability::Durability; use crate::key::DatabaseKeyIndex; use crate::key::DependencyIndex; use crate::runtime::StampedValue; -use crate::storage::IngredientIndex; +use crate::zalsa::IngredientIndex; use crate::tracked_struct::Disambiguator; use crate::Cancelled; use crate::Cycle; diff --git a/src/salsa_struct.rs b/src/salsa_struct.rs index 8e2d3c8d..1d45e94b 100644 --- a/src/salsa_struct.rs +++ b/src/salsa_struct.rs @@ -1,4 +1,4 @@ -use crate::{storage::IngredientIndex, Database}; +use crate::{zalsa::IngredientIndex, Database}; pub trait SalsaStructInDb { fn register_dependent_fn(db: &dyn Database, index: IngredientIndex); diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index b8460c34..1d59ba2b 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -13,7 +13,7 @@ use crate::{ key::{DatabaseKeyIndex, DependencyIndex}, local_state::QueryOrigin, salsa_struct::SalsaStructInDb, - storage::IngredientIndex, + zalsa::IngredientIndex, Database, Durability, Event, Id, Revision, }; @@ -113,7 +113,7 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( &self, - struct_index: crate::storage::IngredientIndex, + struct_index: crate::zalsa::IngredientIndex, ) -> Vec> { let struct_ingredient = IngredientImpl::new(struct_index); let struct_map = &struct_ingredient.struct_map.view(); diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 2061d1f0..6121d833 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,5 +1,5 @@ use crate::{ - id::AsId, ingredient::Ingredient, key::DependencyIndex, storage::IngredientIndex, Database, Id, + id::AsId, ingredient::Ingredient, key::DependencyIndex, zalsa::IngredientIndex, Database, Id, }; use super::{struct_map::StructMapView, Configuration}; diff --git a/src/storage.rs b/src/zalsa.rs similarity index 100% rename from src/storage.rs rename to src/zalsa.rs From 3254f46ca80038db47f227674797aced3c07abfd Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 13:12:31 +0000 Subject: [PATCH 17/29] rename LocalState to ZalsaLocal --- src/accumulator.rs | 4 ++-- src/database.rs | 12 ++++++------ src/function/fetch.rs | 6 +++--- src/function/maybe_changed_after.rs | 4 ++-- src/function/sync.rs | 4 ++-- src/local_state.rs | 12 ++++++------ src/runtime.rs | 6 +++--- src/zalsa.rs | 4 ++-- 8 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/accumulator.rs b/src/accumulator.rs index ea989658..e45ea6e2 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -10,7 +10,7 @@ use crate::{ hash::FxDashMap, ingredient::{fmt_index, Ingredient, Jar}, key::DependencyIndex, - local_state::{LocalState, QueryOrigin}, + local_state::{QueryOrigin, ZalsaLocal}, zalsa::IngredientIndex, Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, }; @@ -109,7 +109,7 @@ impl IngredientImpl { pub(crate) fn produced_by( &self, current_revision: Revision, - local_state: &LocalState, + local_state: &ZalsaLocal, query: DatabaseKeyIndex, output: &mut Vec, ) { diff --git a/src/database.rs b/src/database.rs index c32e1063..5ddfd2f9 100644 --- a/src/database.rs +++ b/src/database.rs @@ -4,7 +4,7 @@ use parking_lot::{Condvar, Mutex}; use crate::{ self as salsa, - local_state::{self, LocalState}, + local_state::{self, ZalsaLocal}, zalsa::Zalsa, Durability, Event, EventKind, Revision, }; @@ -72,7 +72,7 @@ pub unsafe trait Database: Send + AsDynDatabase + Any { /// Access the thread-local state associated with this database #[doc(hidden)] - fn zalsa_local(&self) -> &LocalState; + fn zalsa_local(&self) -> &ZalsaLocal; } /// Upcast to a `dyn Database`. @@ -120,7 +120,7 @@ pub struct DatabaseImpl { coordinate: Arc, /// Per-thread state - zalsa_local: local_state::LocalState, + zalsa_local: local_state::ZalsaLocal, /// The `U` is stored as a `dyn Any` in `zalsa_impl` phantom: PhantomData, @@ -152,7 +152,7 @@ impl DatabaseImpl { clones: Mutex::new(1), cvar: Default::default(), }), - zalsa_local: LocalState::new(), + zalsa_local: ZalsaLocal::new(), phantom: PhantomData::, } } @@ -214,7 +214,7 @@ unsafe impl Database for DatabaseImpl { zalsa_mut } - fn zalsa_local(&self) -> &LocalState { + fn zalsa_local(&self) -> &ZalsaLocal { &self.zalsa_local } @@ -231,7 +231,7 @@ impl Clone for DatabaseImpl { Self { zalsa_impl: self.zalsa_impl.clone(), coordinate: Arc::clone(&self.coordinate), - zalsa_local: LocalState::new(), + zalsa_local: ZalsaLocal::new(), phantom: PhantomData::, } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 4e3018b9..f1ca94f5 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,7 +1,7 @@ use arc_swap::Guard; use crate::{ - local_state::LocalState, runtime::StampedValue, AsDynDatabase as _, Database as _, Id, + local_state::ZalsaLocal, runtime::StampedValue, AsDynDatabase as _, Database as _, Id, }; use super::{Configuration, IngredientImpl}; @@ -37,7 +37,7 @@ where fn compute_value<'db>( &'db self, db: &'db C::DbView, - local_state: &LocalState, + local_state: &ZalsaLocal, key: Id, ) -> StampedValue<&'db C::Output<'db>> { loop { @@ -75,7 +75,7 @@ where fn fetch_cold<'db>( &'db self, db: &'db C::DbView, - local_state: &LocalState, + local_state: &ZalsaLocal, key: Id, ) -> Option>> { let database_key_index = self.database_key_index(key); diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index cb6ec20f..596a1ca9 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,7 +2,7 @@ use arc_swap::Guard; use crate::{ key::DatabaseKeyIndex, - local_state::{ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, + local_state::{ActiveQueryGuard, EdgeKind, QueryOrigin, ZalsaLocal}, runtime::StampedValue, zalsa::Zalsa, AsDynDatabase as _, Database, Id, Revision, @@ -51,7 +51,7 @@ where fn maybe_changed_after_cold<'db>( &'db self, db: &'db C::DbView, - local_state: &LocalState, + local_state: &ZalsaLocal, key_index: Id, revision: Revision, ) -> Option { diff --git a/src/function/sync.rs b/src/function/sync.rs index 72940083..c7973db4 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -4,7 +4,7 @@ use std::{ }; use crate::{ - hash::FxDashMap, key::DatabaseKeyIndex, local_state::LocalState, runtime::WaitResult, + hash::FxDashMap, key::DatabaseKeyIndex, local_state::ZalsaLocal, runtime::WaitResult, zalsa::Zalsa, Database, Id, }; @@ -25,7 +25,7 @@ impl SyncMap { pub(super) fn claim<'me>( &'me self, db: &'me dyn Database, - local_state: &LocalState, + local_state: &ZalsaLocal, database_key_index: DatabaseKeyIndex, ) -> Option> { let zalsa = db.zalsa(); diff --git a/src/local_state.rs b/src/local_state.rs index 3cb6652d..3d944858 100644 --- a/src/local_state.rs +++ b/src/local_state.rs @@ -5,8 +5,8 @@ use crate::durability::Durability; use crate::key::DatabaseKeyIndex; use crate::key::DependencyIndex; use crate::runtime::StampedValue; -use crate::zalsa::IngredientIndex; use crate::tracked_struct::Disambiguator; +use crate::zalsa::IngredientIndex; use crate::Cancelled; use crate::Cycle; use crate::Database; @@ -22,7 +22,7 @@ use std::sync::Arc; /// /// **Note also that all mutations to the database handle (and hence /// to the local-state) must be undone during unwinding.** -pub struct LocalState { +pub struct ZalsaLocal { /// Vector of active queries. /// /// This is normally `Some`, but it is set to `None` @@ -33,9 +33,9 @@ pub struct LocalState { query_stack: RefCell>>, } -impl LocalState { +impl ZalsaLocal { pub(crate) fn new() -> Self { - LocalState { + ZalsaLocal { query_stack: RefCell::new(Some(vec![])), } } @@ -262,7 +262,7 @@ impl LocalState { } } -impl std::panic::RefUnwindSafe for LocalState {} +impl std::panic::RefUnwindSafe for ZalsaLocal {} /// Summarizes "all the inputs that a query used" #[derive(Debug, Clone)] @@ -393,7 +393,7 @@ impl QueryEdges { /// the query from the stack -- in the case of unwinding, the guard's /// destructor will also remove the query. pub(crate) struct ActiveQueryGuard<'me> { - local_state: &'me LocalState, + local_state: &'me ZalsaLocal, push_len: usize, pub(crate) database_key_index: DatabaseKeyIndex, } diff --git a/src/runtime.rs b/src/runtime.rs index 3ac4ae8e..bd10291e 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -9,7 +9,7 @@ use parking_lot::Mutex; use crate::{ active_query::ActiveQuery, cycle::CycleRecoveryStrategy, durability::Durability, - key::DatabaseKeyIndex, local_state::LocalState, revision::AtomicRevision, Cancelled, Cycle, + key::DatabaseKeyIndex, local_state::ZalsaLocal, revision::AtomicRevision, Cancelled, Cycle, Database, Event, EventKind, Revision, }; @@ -171,7 +171,7 @@ impl Runtime { pub(crate) fn block_on_or_unwind( &self, db: &dyn Database, - local_state: &LocalState, + local_state: &ZalsaLocal, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, @@ -231,7 +231,7 @@ impl Runtime { fn unblock_cycle_and_maybe_throw( &self, db: &dyn Database, - local_state: &LocalState, + local_state: &ZalsaLocal, dg: &mut DependencyGraph, database_key_index: DatabaseKeyIndex, to_id: ThreadId, diff --git a/src/zalsa.rs b/src/zalsa.rs index 292fbd1d..a71f68d3 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -8,7 +8,7 @@ use rustc_hash::FxHashMap; use crate::cycle::CycleRecoveryStrategy; use crate::database::UserData; use crate::ingredient::{Ingredient, Jar}; -use crate::local_state::LocalState; +use crate::local_state::ZalsaLocal; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::{Runtime, WaitResult}; use crate::views::Views; @@ -198,7 +198,7 @@ impl Zalsa { pub(crate) fn block_on_or_unwind( &self, db: &dyn Database, - local_state: &LocalState, + local_state: &ZalsaLocal, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, From ab112b7126944bb22fdba6dbb66eebe0247a0035 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 13:12:58 +0000 Subject: [PATCH 18/29] rename local_state to zalsa_local --- src/accumulator.rs | 2 +- src/active_query.rs | 4 ++-- src/database.rs | 4 ++-- src/function.rs | 2 +- src/function/backdate.rs | 2 +- src/function/delete.rs | 2 +- src/function/diff_outputs.rs | 2 +- src/function/execute.rs | 2 +- src/function/fetch.rs | 2 +- src/function/inputs.rs | 2 +- src/function/maybe_changed_after.rs | 2 +- src/function/memo.rs | 4 ++-- src/function/specify.rs | 2 +- src/function/store.rs | 2 +- src/function/sync.rs | 4 ++-- src/ingredient.rs | 2 +- src/input.rs | 2 +- src/input/input_field.rs | 2 +- src/interned.rs | 2 +- src/lib.rs | 2 +- src/runtime.rs | 2 +- src/tracked_struct.rs | 2 +- src/tracked_struct/tracked_field.rs | 2 +- src/zalsa.rs | 2 +- src/{local_state.rs => zalsa_local.rs} | 0 25 files changed, 28 insertions(+), 28 deletions(-) rename src/{local_state.rs => zalsa_local.rs} (100%) diff --git a/src/accumulator.rs b/src/accumulator.rs index e45ea6e2..ac50e6c7 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -10,7 +10,7 @@ use crate::{ hash::FxDashMap, ingredient::{fmt_index, Ingredient, Jar}, key::DependencyIndex, - local_state::{QueryOrigin, ZalsaLocal}, + zalsa_local::{QueryOrigin, ZalsaLocal}, zalsa::IngredientIndex, Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, }; diff --git a/src/active_query.rs b/src/active_query.rs index 8f575b75..18cea6fc 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -2,12 +2,12 @@ use crate::{ durability::Durability, hash::{FxIndexMap, FxIndexSet}, key::{DatabaseKeyIndex, DependencyIndex}, - local_state::EMPTY_DEPENDENCIES, + zalsa_local::EMPTY_DEPENDENCIES, tracked_struct::Disambiguator, Cycle, Revision, }; -use super::local_state::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions}; +use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions}; #[derive(Debug)] pub(crate) struct ActiveQuery { diff --git a/src/database.rs b/src/database.rs index 5ddfd2f9..b541c843 100644 --- a/src/database.rs +++ b/src/database.rs @@ -4,7 +4,7 @@ use parking_lot::{Condvar, Mutex}; use crate::{ self as salsa, - local_state::{self, ZalsaLocal}, + zalsa_local::{self, ZalsaLocal}, zalsa::Zalsa, Durability, Event, EventKind, Revision, }; @@ -120,7 +120,7 @@ pub struct DatabaseImpl { coordinate: Arc, /// Per-thread state - zalsa_local: local_state::ZalsaLocal, + zalsa_local: zalsa_local::ZalsaLocal, /// The `U` is stored as a `dyn Any` in `zalsa_impl` phantom: PhantomData, diff --git a/src/function.rs b/src/function.rs index 176d43e9..0f90aa6f 100644 --- a/src/function.rs +++ b/src/function.rs @@ -4,7 +4,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ cycle::CycleRecoveryStrategy, ingredient::fmt_index, key::DatabaseKeyIndex, - local_state::QueryOrigin, salsa_struct::SalsaStructInDb, zalsa::IngredientIndex, + zalsa_local::QueryOrigin, salsa_struct::SalsaStructInDb, zalsa::IngredientIndex, AsDynDatabase as _, Cycle, Database, Event, EventKind, Id, Revision, }; diff --git a/src/function/backdate.rs b/src/function/backdate.rs index 64fa0caf..bfca6f05 100644 --- a/src/function/backdate.rs +++ b/src/function/backdate.rs @@ -1,4 +1,4 @@ -use crate::local_state::QueryRevisions; +use crate::zalsa_local::QueryRevisions; use super::{memo::Memo, Configuration, IngredientImpl}; diff --git a/src/function/delete.rs b/src/function/delete.rs index 8c1dc6f2..131c82cf 100644 --- a/src/function/delete.rs +++ b/src/function/delete.rs @@ -1,7 +1,7 @@ use arc_swap::ArcSwap; use crossbeam::queue::SegQueue; -use crate::{local_state::QueryOrigin, Id}; +use crate::{zalsa_local::QueryOrigin, Id}; use super::{memo, Configuration, IngredientImpl}; diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index 96df56f7..f365abce 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -1,5 +1,5 @@ use crate::{ - hash::FxHashSet, key::DependencyIndex, local_state::QueryRevisions, AsDynDatabase as _, + hash::FxHashSet, key::DependencyIndex, zalsa_local::QueryRevisions, AsDynDatabase as _, DatabaseKeyIndex, Event, EventKind, }; diff --git a/src/function/execute.rs b/src/function/execute.rs index cba601d0..0bc2ace6 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{ - local_state::ActiveQueryGuard, runtime::StampedValue, Cycle, Database, Event, EventKind, + zalsa_local::ActiveQueryGuard, runtime::StampedValue, Cycle, Database, Event, EventKind, }; use super::{memo::Memo, Configuration, IngredientImpl}; diff --git a/src/function/fetch.rs b/src/function/fetch.rs index f1ca94f5..a30fe12a 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,7 +1,7 @@ use arc_swap::Guard; use crate::{ - local_state::ZalsaLocal, runtime::StampedValue, AsDynDatabase as _, Database as _, Id, + zalsa_local::ZalsaLocal, runtime::StampedValue, AsDynDatabase as _, Database as _, Id, }; use super::{Configuration, IngredientImpl}; diff --git a/src/function/inputs.rs b/src/function/inputs.rs index ae54ca7e..ff5abc58 100644 --- a/src/function/inputs.rs +++ b/src/function/inputs.rs @@ -1,4 +1,4 @@ -use crate::{local_state::QueryOrigin, Id}; +use crate::{zalsa_local::QueryOrigin, Id}; use super::{Configuration, IngredientImpl}; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 596a1ca9..a994f577 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,7 +2,7 @@ use arc_swap::Guard; use crate::{ key::DatabaseKeyIndex, - local_state::{ActiveQueryGuard, EdgeKind, QueryOrigin, ZalsaLocal}, + zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin, ZalsaLocal}, runtime::StampedValue, zalsa::Zalsa, AsDynDatabase as _, Database, Id, Revision, diff --git a/src/function/memo.rs b/src/function/memo.rs index e84782ca..ac371fe4 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -4,7 +4,7 @@ use arc_swap::{ArcSwap, Guard}; use crossbeam::atomic::AtomicCell; use crate::{ - hash::FxDashMap, key::DatabaseKeyIndex, local_state::QueryRevisions, zalsa::Zalsa, Event, + hash::FxDashMap, key::DatabaseKeyIndex, zalsa_local::QueryRevisions, zalsa::Zalsa, Event, EventKind, Id, Revision, }; @@ -78,7 +78,7 @@ impl MemoMap { /// with an equivalent memo that has no value. If the memo is untracked, BaseInput, /// or has values assigned as output of another query, this has no effect. pub(super) fn evict(&self, key: Id) { - use crate::local_state::QueryOrigin; + use crate::zalsa_local::QueryOrigin; use dashmap::mapref::entry::Entry::*; if let Occupied(entry) = self.map.entry(key) { diff --git a/src/function/specify.rs b/src/function/specify.rs index 98945dc5..bc598386 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -1,7 +1,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ - local_state::{QueryOrigin, QueryRevisions}, + zalsa_local::{QueryOrigin, QueryRevisions}, tracked_struct::TrackedStructInDb, AsDynDatabase as _, Database, DatabaseKeyIndex, Id, }; diff --git a/src/function/store.rs b/src/function/store.rs index aefe1c48..79757106 100644 --- a/src/function/store.rs +++ b/src/function/store.rs @@ -4,7 +4,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ durability::Durability, - local_state::{QueryOrigin, QueryRevisions}, + zalsa_local::{QueryOrigin, QueryRevisions}, Id, Runtime, }; diff --git a/src/function/sync.rs b/src/function/sync.rs index c7973db4..23f6cf9c 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -4,8 +4,8 @@ use std::{ }; use crate::{ - hash::FxDashMap, key::DatabaseKeyIndex, local_state::ZalsaLocal, runtime::WaitResult, - zalsa::Zalsa, Database, Id, + hash::FxDashMap, key::DatabaseKeyIndex, runtime::WaitResult, zalsa::Zalsa, + zalsa_local::ZalsaLocal, Database, Id, }; #[derive(Default)] diff --git a/src/ingredient.rs b/src/ingredient.rs index 3f053604..c8b26cfa 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -4,7 +4,7 @@ use std::{ }; use crate::{ - cycle::CycleRecoveryStrategy, local_state::QueryOrigin, zalsa::IngredientIndex, Database, + cycle::CycleRecoveryStrategy, zalsa_local::QueryOrigin, zalsa::IngredientIndex, Database, DatabaseKeyIndex, Id, }; diff --git a/src/input.rs b/src/input.rs index 10c08c62..c884fa8f 100644 --- a/src/input.rs +++ b/src/input.rs @@ -17,7 +17,7 @@ use crate::{ id::{AsId, FromId}, ingredient::{fmt_index, Ingredient}, key::{DatabaseKeyIndex, DependencyIndex}, - local_state::QueryOrigin, + zalsa_local::QueryOrigin, plumbing::{Jar, Stamp}, zalsa::IngredientIndex, Database, Durability, Id, Revision, diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 99f6f65a..0eb9baca 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,7 +1,7 @@ use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::{fmt_index, Ingredient}; use crate::input::Configuration; -use crate::local_state::QueryOrigin; +use crate::zalsa_local::QueryOrigin; use crate::zalsa::IngredientIndex; use crate::{Database, DatabaseKeyIndex, Id, Revision}; use std::fmt; diff --git a/src/interned.rs b/src/interned.rs index e6e76d88..1395696e 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -9,7 +9,7 @@ use crate::durability::Durability; use crate::id::AsId; use crate::ingredient::fmt_index; use crate::key::DependencyIndex; -use crate::local_state::QueryOrigin; +use crate::zalsa_local::QueryOrigin; use crate::plumbing::Jar; use crate::zalsa::IngredientIndex; use crate::{Database, DatabaseKeyIndex, Id}; diff --git a/src/lib.rs b/src/lib.rs index aa018b45..b3eb0a6b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,7 @@ mod ingredient_list; mod input; mod interned; mod key; -mod local_state; +mod zalsa_local; mod nonce; mod revision; mod runtime; diff --git a/src/runtime.rs b/src/runtime.rs index bd10291e..5330f7bf 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -9,7 +9,7 @@ use parking_lot::Mutex; use crate::{ active_query::ActiveQuery, cycle::CycleRecoveryStrategy, durability::Durability, - key::DatabaseKeyIndex, local_state::ZalsaLocal, revision::AtomicRevision, Cancelled, Cycle, + key::DatabaseKeyIndex, zalsa_local::ZalsaLocal, revision::AtomicRevision, Cancelled, Cycle, Database, Event, EventKind, Revision, }; diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 1d59ba2b..8fbb0aec 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -11,7 +11,7 @@ use crate::{ ingredient::{fmt_index, Ingredient, Jar}, ingredient_list::IngredientList, key::{DatabaseKeyIndex, DependencyIndex}, - local_state::QueryOrigin, + zalsa_local::QueryOrigin, salsa_struct::SalsaStructInDb, zalsa::IngredientIndex, Database, Durability, Event, Id, Revision, diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 6121d833..07a93c19 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -91,7 +91,7 @@ where field_changed_at > revision } - fn origin(&self, _key_index: crate::Id) -> Option { + fn origin(&self, _key_index: crate::Id) -> Option { None } diff --git a/src/zalsa.rs b/src/zalsa.rs index a71f68d3..ea4ae585 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -8,10 +8,10 @@ use rustc_hash::FxHashMap; use crate::cycle::CycleRecoveryStrategy; use crate::database::UserData; use crate::ingredient::{Ingredient, Jar}; -use crate::local_state::ZalsaLocal; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::{Runtime, WaitResult}; use crate::views::Views; +use crate::zalsa_local::ZalsaLocal; use crate::{Database, DatabaseImpl, DatabaseKeyIndex, Durability, Revision}; pub fn views(db: &Db) -> &Views { diff --git a/src/local_state.rs b/src/zalsa_local.rs similarity index 100% rename from src/local_state.rs rename to src/zalsa_local.rs From 502716d3680e6ae4f558364f962929156dcc1af7 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 21:33:05 +0000 Subject: [PATCH 19/29] pacify the merciless cargo fmt --- src/accumulator.rs | 2 +- src/active_query.rs | 2 +- src/database.rs | 2 +- src/function.rs | 2 +- src/function/execute.rs | 2 +- src/function/fetch.rs | 2 +- src/function/maybe_changed_after.rs | 2 +- src/function/memo.rs | 2 +- src/function/specify.rs | 2 +- src/ingredient.rs | 2 +- src/input.rs | 2 +- src/input/input_field.rs | 2 +- src/interned.rs | 2 +- src/lib.rs | 2 +- src/runtime.rs | 2 +- src/tracked_struct.rs | 2 +- 16 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/accumulator.rs b/src/accumulator.rs index ac50e6c7..d388c8fc 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -10,8 +10,8 @@ use crate::{ hash::FxDashMap, ingredient::{fmt_index, Ingredient, Jar}, key::DependencyIndex, - zalsa_local::{QueryOrigin, ZalsaLocal}, zalsa::IngredientIndex, + zalsa_local::{QueryOrigin, ZalsaLocal}, Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, }; diff --git a/src/active_query.rs b/src/active_query.rs index 18cea6fc..bc723655 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -2,8 +2,8 @@ use crate::{ durability::Durability, hash::{FxIndexMap, FxIndexSet}, key::{DatabaseKeyIndex, DependencyIndex}, - zalsa_local::EMPTY_DEPENDENCIES, tracked_struct::Disambiguator, + zalsa_local::EMPTY_DEPENDENCIES, Cycle, Revision, }; diff --git a/src/database.rs b/src/database.rs index b541c843..aaa385d9 100644 --- a/src/database.rs +++ b/src/database.rs @@ -4,8 +4,8 @@ use parking_lot::{Condvar, Mutex}; use crate::{ self as salsa, - zalsa_local::{self, ZalsaLocal}, zalsa::Zalsa, + zalsa_local::{self, ZalsaLocal}, Durability, Event, EventKind, Revision, }; diff --git a/src/function.rs b/src/function.rs index 0f90aa6f..bd39a4fe 100644 --- a/src/function.rs +++ b/src/function.rs @@ -4,7 +4,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ cycle::CycleRecoveryStrategy, ingredient::fmt_index, key::DatabaseKeyIndex, - zalsa_local::QueryOrigin, salsa_struct::SalsaStructInDb, zalsa::IngredientIndex, + salsa_struct::SalsaStructInDb, zalsa::IngredientIndex, zalsa_local::QueryOrigin, AsDynDatabase as _, Cycle, Database, Event, EventKind, Id, Revision, }; diff --git a/src/function/execute.rs b/src/function/execute.rs index 0bc2ace6..4d4fd821 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{ - zalsa_local::ActiveQueryGuard, runtime::StampedValue, Cycle, Database, Event, EventKind, + runtime::StampedValue, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind, }; use super::{memo::Memo, Configuration, IngredientImpl}; diff --git a/src/function/fetch.rs b/src/function/fetch.rs index a30fe12a..545bd9f8 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,7 +1,7 @@ use arc_swap::Guard; use crate::{ - zalsa_local::ZalsaLocal, runtime::StampedValue, AsDynDatabase as _, Database as _, Id, + runtime::StampedValue, zalsa_local::ZalsaLocal, AsDynDatabase as _, Database as _, Id, }; use super::{Configuration, IngredientImpl}; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index a994f577..8c9a4b1a 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,9 +2,9 @@ use arc_swap::Guard; use crate::{ key::DatabaseKeyIndex, - zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin, ZalsaLocal}, runtime::StampedValue, zalsa::Zalsa, + zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin, ZalsaLocal}, AsDynDatabase as _, Database, Id, Revision, }; diff --git a/src/function/memo.rs b/src/function/memo.rs index ac371fe4..22b97bdc 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -4,7 +4,7 @@ use arc_swap::{ArcSwap, Guard}; use crossbeam::atomic::AtomicCell; use crate::{ - hash::FxDashMap, key::DatabaseKeyIndex, zalsa_local::QueryRevisions, zalsa::Zalsa, Event, + hash::FxDashMap, key::DatabaseKeyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, Event, EventKind, Id, Revision, }; diff --git a/src/function/specify.rs b/src/function/specify.rs index bc598386..a1624e0e 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -1,8 +1,8 @@ use crossbeam::atomic::AtomicCell; use crate::{ - zalsa_local::{QueryOrigin, QueryRevisions}, tracked_struct::TrackedStructInDb, + zalsa_local::{QueryOrigin, QueryRevisions}, AsDynDatabase as _, Database, DatabaseKeyIndex, Id, }; diff --git a/src/ingredient.rs b/src/ingredient.rs index c8b26cfa..85594e86 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -4,7 +4,7 @@ use std::{ }; use crate::{ - cycle::CycleRecoveryStrategy, zalsa_local::QueryOrigin, zalsa::IngredientIndex, Database, + cycle::CycleRecoveryStrategy, zalsa::IngredientIndex, zalsa_local::QueryOrigin, Database, DatabaseKeyIndex, Id, }; diff --git a/src/input.rs b/src/input.rs index c884fa8f..5ab315f9 100644 --- a/src/input.rs +++ b/src/input.rs @@ -17,9 +17,9 @@ use crate::{ id::{AsId, FromId}, ingredient::{fmt_index, Ingredient}, key::{DatabaseKeyIndex, DependencyIndex}, - zalsa_local::QueryOrigin, plumbing::{Jar, Stamp}, zalsa::IngredientIndex, + zalsa_local::QueryOrigin, Database, Durability, Id, Revision, }; diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 0eb9baca..c6f628d0 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,8 +1,8 @@ use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::{fmt_index, Ingredient}; use crate::input::Configuration; -use crate::zalsa_local::QueryOrigin; use crate::zalsa::IngredientIndex; +use crate::zalsa_local::QueryOrigin; use crate::{Database, DatabaseKeyIndex, Id, Revision}; use std::fmt; diff --git a/src/interned.rs b/src/interned.rs index 1395696e..65e56437 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -9,9 +9,9 @@ use crate::durability::Durability; use crate::id::AsId; use crate::ingredient::fmt_index; use crate::key::DependencyIndex; -use crate::zalsa_local::QueryOrigin; use crate::plumbing::Jar; use crate::zalsa::IngredientIndex; +use crate::zalsa_local::QueryOrigin; use crate::{Database, DatabaseKeyIndex, Id}; use super::hash::FxDashMap; diff --git a/src/lib.rs b/src/lib.rs index b3eb0a6b..d64cebe4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,6 @@ mod ingredient_list; mod input; mod interned; mod key; -mod zalsa_local; mod nonce; mod revision; mod runtime; @@ -25,6 +24,7 @@ mod tracked_struct; mod update; mod views; mod zalsa; +mod zalsa_local; pub use self::accumulator::Accumulator; pub use self::cancelled::Cancelled; diff --git a/src/runtime.rs b/src/runtime.rs index 5330f7bf..02f8a1d7 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -9,7 +9,7 @@ use parking_lot::Mutex; use crate::{ active_query::ActiveQuery, cycle::CycleRecoveryStrategy, durability::Durability, - key::DatabaseKeyIndex, zalsa_local::ZalsaLocal, revision::AtomicRevision, Cancelled, Cycle, + key::DatabaseKeyIndex, revision::AtomicRevision, zalsa_local::ZalsaLocal, Cancelled, Cycle, Database, Event, EventKind, Revision, }; diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 8fbb0aec..b0eb049e 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -11,9 +11,9 @@ use crate::{ ingredient::{fmt_index, Ingredient, Jar}, ingredient_list::IngredientList, key::{DatabaseKeyIndex, DependencyIndex}, - zalsa_local::QueryOrigin, salsa_struct::SalsaStructInDb, zalsa::IngredientIndex, + zalsa_local::QueryOrigin, Database, Durability, Event, Id, Revision, }; From 4995ce0ddcf82809c7614ba25406f26c518d1d11 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 29 Jul 2024 09:58:57 +0200 Subject: [PATCH 20/29] Relax dependency constraints --- Cargo.toml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 611aa69b..762d5b54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,20 +8,20 @@ repository = "https://github.com/salsa-rs/salsa" description = "A generic framework for on-demand, incrementalized computation (experimental)" [dependencies] -arc-swap = "1.6.0" -boomphf = "0.6.0" -crossbeam = "0.8.1" -dashmap = "6.0.1" -hashlink = "0.9.1" +arc-swap = "1" +boomphf = "0.6" +crossbeam = "0.8" +dashmap = "6" +hashlink = "0.9" indexmap = "2" -orx-concurrent-vec = "2.2.0" +orx-concurrent-vec = "2" tracing = "0.1" -parking_lot = "0.12.1" -rustc-hash = "2.0.0" +parking_lot = "0.12" +rustc-hash = "2" salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" } salsa-macros = { path = "components/salsa-macros" } -smallvec = "1.0.0" -lazy_static = "1.5.0" +smallvec = "1" +lazy_static = "1" [dev-dependencies] annotate-snippets = "0.11.4" From 12e0741252701c494ebd28050f929f81a30a9438 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 29 Jul 2024 10:25:26 +0200 Subject: [PATCH 21/29] Implement DerefMut --- src/database.rs | 9 +++++++++ src/zalsa.rs | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/src/database.rs b/src/database.rs index aaa385d9..20781001 100644 --- a/src/database.rs +++ b/src/database.rs @@ -196,6 +196,15 @@ impl std::ops::Deref for DatabaseImpl { } } +impl std::ops::DerefMut for DatabaseImpl { + fn deref_mut(&mut self) -> &mut U { + self.zalsa_mut() + .user_data_mut() + .downcast_mut::() + .unwrap() + } +} + impl RefUnwindSafe for DatabaseImpl {} #[salsa_macros::db] diff --git a/src/zalsa.rs b/src/zalsa.rs index ea4ae585..f18930f3 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -179,6 +179,10 @@ impl Zalsa { &*self.user_data } + pub(crate) fn user_data_mut(&mut self) -> &mut (dyn Any + Send + Sync) { + &mut *self.user_data + } + /// Triggers a new revision. Invoked automatically when you call `zalsa_mut` /// and so doesn't need to be called otherwise. pub(crate) fn new_revision(&mut self) -> Revision { From 4d2ccffddc2f168316c42e8031b710baaffadfd0 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 4 Aug 2024 01:33:58 -0400 Subject: [PATCH 22/29] return to the database-wrapping-storage setup --- components/salsa-macros/src/db.rs | 52 ++++- examples/calc/db.rs | 18 +- examples/lazy-input/main.rs | 20 +- src/database.rs | 209 +----------------- src/database_impl.rs | 24 ++ src/function/accumulated.rs | 2 +- src/function/execute.rs | 3 +- src/function/fetch.rs | 2 +- src/function/maybe_changed_after.rs | 4 +- src/function/specify.rs | 1 + src/lib.rs | 11 +- src/storage.rs | 137 ++++++++++++ src/zalsa.rs | 44 ++-- tests/accumulate-reuse-workaround.rs | 6 +- tests/accumulate-reuse.rs | 6 +- tests/accumulate.rs | 10 +- tests/common/mod.rs | 82 ++++--- tests/deletion-cascade.rs | 6 +- tests/deletion.rs | 2 +- tests/elided-lifetime-in-tracked-fn.rs | 6 +- ...truct_changes_but_fn_depends_on_field_y.rs | 6 +- ...input_changes_but_fn_depends_on_field_y.rs | 4 +- tests/hello_world.rs | 6 +- tests/lru.rs | 10 +- tests/parallel/parallel_cancellation.rs | 5 +- tests/parallel/parallel_cycle_all_recover.rs | 6 +- tests/parallel/parallel_cycle_mid_recover.rs | 6 +- tests/parallel/parallel_cycle_none_recover.rs | 5 +- tests/parallel/parallel_cycle_one_recover.rs | 6 +- tests/parallel/setup.rs | 37 +++- tests/preverify-struct-with-leaked-data.rs | 4 +- tests/synthetic_write.rs | 4 +- tests/tracked-struct-value-field-bad-eq.rs | 2 +- tests/tracked_fn_no_eq.rs | 6 +- tests/tracked_fn_read_own_entity.rs | 6 +- tests/tracked_fn_read_own_specify.rs | 4 +- 36 files changed, 415 insertions(+), 347 deletions(-) create mode 100644 src/database_impl.rs create mode 100644 src/storage.rs diff --git a/components/salsa-macros/src/db.rs b/components/salsa-macros/src/db.rs index e70ad745..02fcef4b 100644 --- a/components/salsa-macros/src/db.rs +++ b/components/salsa-macros/src/db.rs @@ -32,6 +32,13 @@ struct DbMacro { impl DbMacro { fn try_db(self, input: syn::Item) -> syn::Result { match input { + syn::Item::Struct(input) => { + let has_storage_impl = self.has_storage_impl(&input)?; + Ok(quote! { + #has_storage_impl + #input + }) + } syn::Item::Trait(mut input) => { self.add_salsa_view_method(&mut input)?; Ok(quote! { @@ -46,11 +53,54 @@ impl DbMacro { } _ => Err(syn::Error::new_spanned( input, - "`db` must be applied to a trait or impl", + "`db` must be applied to a struct, trait, or impl", )), } } + fn find_storage_field(&self, input: &syn::ItemStruct) -> syn::Result { + let storage = "storage"; + for field in input.fields.iter() { + if let Some(i) = &field.ident { + if i == storage { + return Ok(i.clone()); + } + } else { + return Err(syn::Error::new_spanned( + field, + "database struct must be a braced struct (`{}`) with a field named `storage`", + )); + } + } + + Err(syn::Error::new_spanned( + &input.ident, + "database struct must be a braced struct (`{}`) with a field named `storage`", + )) + } + + fn has_storage_impl(&self, input: &syn::ItemStruct) -> syn::Result { + let storage = self.find_storage_field(input)?; + let db = &input.ident; + let zalsa = self.hygiene.ident("zalsa"); + + Ok(quote! { + const _: () = { + use salsa::plumbing as #zalsa; + + unsafe impl #zalsa::HasStorage for #db { + fn storage(&self) -> &#zalsa::Storage { + &self.#storage + } + + fn storage_mut(&mut self) -> &mut #zalsa::Storage { + &mut self.#storage + } + } + }; + }) + } + fn add_salsa_view_method(&self, input: &mut syn::ItemTrait) -> syn::Result<()> { input.items.push(parse_quote! { #[doc(hidden)] diff --git a/examples/calc/db.rs b/examples/calc/db.rs index 5f1d98f7..2873ed5b 100644 --- a/examples/calc/db.rs +++ b/examples/calc/db.rs @@ -1,18 +1,17 @@ use std::sync::{Arc, Mutex}; -use salsa::UserData; - -pub type CalcDatabaseImpl = salsa::DatabaseImpl; - // ANCHOR: db_struct +#[salsa::db] #[derive(Default)] -pub struct Calc { +pub struct CalcDatabaseImpl { + storage: salsa::Storage, + // The logs are only used for testing and demonstrating reuse: logs: Arc>>>, } // ANCHOR_END: db_struct -impl Calc { +impl CalcDatabaseImpl { /// Enable logging of each salsa event. #[cfg(test)] pub fn enable_logging(&self) { @@ -34,12 +33,13 @@ impl Calc { } // ANCHOR: db_impl -impl UserData for Calc { - fn salsa_event(db: &CalcDatabaseImpl, event: &dyn Fn() -> salsa::Event) { +#[salsa::db] +impl salsa::Database for CalcDatabaseImpl { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { let event = event(); eprintln!("Event: {event:?}"); // Log interesting events, if logging is enabled - if let Some(logs) = &mut *db.logs.lock().unwrap() { + if let Some(logs) = &mut *self.logs.lock().unwrap() { // only log interesting events if let salsa::EventKind::WillExecute { .. } = event.kind { logs.push(format!("Event: {event:?}")); diff --git a/examples/lazy-input/main.rs b/examples/lazy-input/main.rs index 3891367b..3b918888 100644 --- a/examples/lazy-input/main.rs +++ b/examples/lazy-input/main.rs @@ -8,13 +8,13 @@ use notify_debouncer_mini::{ notify::{RecommendedWatcher, RecursiveMode}, DebounceEventResult, Debouncer, }; -use salsa::{Accumulator, DatabaseImpl, Setter, UserData}; +use salsa::{Accumulator, Setter, Storage}; // ANCHOR: main fn main() -> Result<()> { // Create the channel to receive file change events. let (tx, rx) = unbounded(); - let mut db = DatabaseImpl::with(LazyInput::new(tx)); + let mut db = LazyInputDatabase::new(tx); let initial_file_path = std::env::args_os() .nth(1) @@ -74,15 +74,18 @@ trait Db: salsa::Database { fn input(&self, path: PathBuf) -> Result; } -struct LazyInput { +#[salsa::db] +struct LazyInputDatabase { + storage: Storage, logs: Mutex>, files: DashMap, file_watcher: Mutex>, } -impl LazyInput { +impl LazyInputDatabase { fn new(tx: Sender) -> Self { Self { + storage: Default::default(), logs: Default::default(), files: DashMap::new(), file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), tx).unwrap()), @@ -90,18 +93,19 @@ impl LazyInput { } } -impl UserData for LazyInput { - fn salsa_event(db: &DatabaseImpl, event: &dyn Fn() -> salsa::Event) { +#[salsa::db] +impl salsa::Database for LazyInputDatabase { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { // don't log boring events let event = event(); if let salsa::EventKind::WillExecute { .. } = event.kind { - db.logs.lock().unwrap().push(format!("{:?}", event)); + self.logs.lock().unwrap().push(format!("{:?}", event)); } } } #[salsa::db] -impl Db for DatabaseImpl { +impl Db for LazyInputDatabase { fn input(&self, path: PathBuf) -> Result { let path = path .canonicalize() diff --git a/src/database.rs b/src/database.rs index 20781001..801e2cd8 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,24 +1,11 @@ -use std::{any::Any, marker::PhantomData, panic::RefUnwindSafe, sync::Arc}; +use std::any::Any; -use parking_lot::{Condvar, Mutex}; - -use crate::{ - self as salsa, - zalsa::Zalsa, - zalsa_local::{self, ZalsaLocal}, - Durability, Event, EventKind, Revision, -}; +use crate::{zalsa::ZalsaDatabase, Durability, Event, Revision}; /// The trait implemented by all Salsa databases. -/// You can create your own subtraits of this trait using the `#[salsa::db]` procedural macro. -/// -/// # Safety -/// -/// This trait can only safely be implemented by Salsa's [`DatabaseImpl`][] type. -/// -/// FIXME: Document better the unsafety conditions we require. -#[salsa_macros::db] -pub unsafe trait Database: Send + AsDynDatabase + Any { +/// You can create your own subtraits of this trait using the `#[salsa::db]`(`crate::db`) procedural macro. +#[crate::db] +pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { /// This function is invoked by the salsa runtime at various points during execution. /// You can customize what happens by implementing the [`UserData`][] trait. /// By default, the event is logged at level debug using tracing facade. @@ -58,21 +45,6 @@ pub unsafe trait Database: Send + AsDynDatabase + Any { { crate::attach::attach(self, || op(self)) } - - /// Plumbing method: Access the internal salsa methods. - #[doc(hidden)] - fn zalsa(&self) -> &Zalsa; - - /// Plumbing method: Access the internal salsa methods for mutating the database. - /// - /// **WARNING:** Triggers a new revision, canceling other database handles. - /// This can lead to deadlock! - #[doc(hidden)] - fn zalsa_mut(&mut self) -> &mut Zalsa; - - /// Access the thread-local state associated with this database - #[doc(hidden)] - fn zalsa_local(&self) -> &ZalsaLocal; } /// Upcast to a `dyn Database`. @@ -108,174 +80,3 @@ impl dyn Database { self.zalsa().views().try_view_as(self).unwrap() } } - -/// Concrete implementation of the [`Database`][] trait. -/// Takes an optional type parameter `U` that allows you to thread your own data. -pub struct DatabaseImpl { - /// Reference to the database. This is always `Some` except during destruction. - zalsa_impl: Option>, - - /// Coordination data for cancellation of other handles when `zalsa_mut` is called. - /// This could be stored in Zalsa but it makes things marginally cleaner to keep it separate. - coordinate: Arc, - - /// Per-thread state - zalsa_local: zalsa_local::ZalsaLocal, - - /// The `U` is stored as a `dyn Any` in `zalsa_impl` - phantom: PhantomData, -} - -impl Default for DatabaseImpl { - fn default() -> Self { - Self::with(U::default()) - } -} - -impl DatabaseImpl<()> { - /// Create a new database with the given user data. - /// - /// You can also use the [`Default`][] trait if your userdata implements it. - pub fn new() -> Self { - Self::with(()) - } -} - -impl DatabaseImpl { - /// Create a new database with the given user data. - /// - /// You can also use the [`Default`][] trait if your userdata implements it. - pub fn with(u: U) -> Self { - Self { - zalsa_impl: Some(Arc::new(Zalsa::with(u))), - coordinate: Arc::new(Coordinate { - clones: Mutex::new(1), - cvar: Default::default(), - }), - zalsa_local: ZalsaLocal::new(), - phantom: PhantomData::, - } - } - - /// Access the `Arc`. This should always be - /// possible as `zalsa_impl` only becomes - /// `None` once we are in the `Drop` impl. - fn zalsa_impl(&self) -> &Arc { - self.zalsa_impl.as_ref().unwrap() - } - - // ANCHOR: cancel_other_workers - /// Sets cancellation flag and blocks until all other workers with access - /// to this storage have completed. - /// - /// This could deadlock if there is a single worker with two handles to the - /// same database! - fn cancel_others(&mut self) { - let zalsa = self.zalsa_impl(); - zalsa.set_cancellation_flag(); - - self.salsa_event(&|| Event { - thread_id: std::thread::current().id(), - - kind: EventKind::DidSetCancellationFlag, - }); - - let mut clones = self.coordinate.clones.lock(); - while *clones != 1 { - self.coordinate.cvar.wait(&mut clones); - } - } - // ANCHOR_END: cancel_other_workers -} - -impl std::ops::Deref for DatabaseImpl { - type Target = U; - - fn deref(&self) -> &U { - self.zalsa_impl().user_data().downcast_ref::().unwrap() - } -} - -impl std::ops::DerefMut for DatabaseImpl { - fn deref_mut(&mut self) -> &mut U { - self.zalsa_mut() - .user_data_mut() - .downcast_mut::() - .unwrap() - } -} - -impl RefUnwindSafe for DatabaseImpl {} - -#[salsa_macros::db] -unsafe impl Database for DatabaseImpl { - fn zalsa(&self) -> &Zalsa { - &**self.zalsa_impl() - } - - fn zalsa_mut(&mut self) -> &mut Zalsa { - self.cancel_others(); - - // The ref count on the `Arc` should now be 1 - let arc_zalsa_mut = self.zalsa_impl.as_mut().unwrap(); - let zalsa_mut = Arc::get_mut(arc_zalsa_mut).unwrap(); - zalsa_mut.new_revision(); - zalsa_mut - } - - fn zalsa_local(&self) -> &ZalsaLocal { - &self.zalsa_local - } - - // Report a salsa event. - fn salsa_event(&self, event: &dyn Fn() -> Event) { - U::salsa_event(self, event) - } -} - -impl Clone for DatabaseImpl { - fn clone(&self) -> Self { - *self.coordinate.clones.lock() += 1; - - Self { - zalsa_impl: self.zalsa_impl.clone(), - coordinate: Arc::clone(&self.coordinate), - zalsa_local: ZalsaLocal::new(), - phantom: PhantomData::, - } - } -} - -impl Drop for DatabaseImpl { - fn drop(&mut self) { - // Drop the database handle *first* - self.zalsa_impl.take(); - - // *Now* decrement the number of clones and notify once we have completed - *self.coordinate.clones.lock() -= 1; - self.coordinate.cvar.notify_all(); - } -} - -pub trait UserData: Any + Sized + Send + Sync { - /// Callback invoked by the [`Database`][] at key points during salsa execution. - /// By overriding this method, you can inject logging or other custom behavior. - /// - /// By default, the event is logged at level debug using the `tracing` crate. - /// - /// # Parameters - /// - /// * `event` a fn that, if called, will return the event that occurred - fn salsa_event(_db: &DatabaseImpl, event: &dyn Fn() -> Event) { - tracing::debug!("salsa_event: {:?}", event()) - } -} - -impl UserData for () {} - -struct Coordinate { - /// Counter of the number of clones of actor. Begins at 1. - /// Incremented when cloned, decremented when dropped. - clones: Mutex, - cvar: Condvar, -} diff --git a/src/database_impl.rs b/src/database_impl.rs new file mode 100644 index 00000000..71da9fff --- /dev/null +++ b/src/database_impl.rs @@ -0,0 +1,24 @@ +use crate::{self as salsa, Database, Event, Storage}; + +#[salsa::db] +/// Default database implementation that you can use if you don't +/// require any custom user data. +#[derive(Default)] +pub struct DatabaseImpl { + storage: Storage, +} + +impl DatabaseImpl { + /// Create a new database; equivalent to `Self::default`. + pub fn new() -> Self { + Self::default() + } +} + +#[salsa::db] +impl Database for DatabaseImpl { + /// Default behavior: tracing debug log the event. + fn salsa_event(&self, event: &dyn Fn() -> Event) { + tracing::debug!("salsa_event({:?})", event()); + } +} diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 62c85930..56c4598f 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,4 +1,4 @@ -use crate::{accumulator, hash::FxHashSet, Database, DatabaseKeyIndex, Id}; +use crate::{accumulator, hash::FxHashSet, zalsa::ZalsaDatabase, DatabaseKeyIndex, Id}; use super::{Configuration, IngredientImpl}; diff --git a/src/function/execute.rs b/src/function/execute.rs index 4d4fd821..ddf58736 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,7 +1,8 @@ use std::sync::Arc; use crate::{ - runtime::StampedValue, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind, + runtime::StampedValue, zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, + Event, EventKind, }; use super::{memo::Memo, Configuration, IngredientImpl}; diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 545bd9f8..2bf87846 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,7 +1,7 @@ use arc_swap::Guard; use crate::{ - runtime::StampedValue, zalsa_local::ZalsaLocal, AsDynDatabase as _, Database as _, Id, + runtime::StampedValue, zalsa::ZalsaDatabase, zalsa_local::ZalsaLocal, AsDynDatabase as _, Id, }; use super::{Configuration, IngredientImpl}; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 8c9a4b1a..e81119e8 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -3,9 +3,9 @@ use arc_swap::Guard; use crate::{ key::DatabaseKeyIndex, runtime::StampedValue, - zalsa::Zalsa, + zalsa::{Zalsa, ZalsaDatabase}, zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin, ZalsaLocal}, - AsDynDatabase as _, Database, Id, Revision, + AsDynDatabase as _, Id, Revision, }; use super::{memo::Memo, Configuration, IngredientImpl}; diff --git a/src/function/specify.rs b/src/function/specify.rs index a1624e0e..94d7a287 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -2,6 +2,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ tracked_struct::TrackedStructInDb, + zalsa::ZalsaDatabase, zalsa_local::{QueryOrigin, QueryRevisions}, AsDynDatabase as _, Database, DatabaseKeyIndex, Id, }; diff --git a/src/lib.rs b/src/lib.rs index d64cebe4..d1b22e70 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ mod attach; mod cancelled; mod cycle; mod database; +mod database_impl; mod durability; mod event; mod function; @@ -20,6 +21,7 @@ mod nonce; mod revision; mod runtime; mod salsa_struct; +mod storage; mod tracked_struct; mod update; mod views; @@ -31,8 +33,7 @@ pub use self::cancelled::Cancelled; pub use self::cycle::Cycle; pub use self::database::AsDynDatabase; pub use self::database::Database; -pub use self::database::DatabaseImpl; -pub use self::database::UserData; +pub use self::database_impl::DatabaseImpl; pub use self::durability::Durability; pub use self::event::Event; pub use self::event::EventKind; @@ -41,6 +42,7 @@ pub use self::input::setter::Setter; pub use self::key::DatabaseKeyIndex; pub use self::revision::Revision; pub use self::runtime::Runtime; +pub use self::storage::Storage; pub use self::update::Update; pub use crate::attach::with_attached_database; pub use salsa_macros::accumulator; @@ -70,7 +72,6 @@ pub mod plumbing { pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::current_revision; pub use crate::database::Database; - pub use crate::database::UserData; pub use crate::function::should_backdate_value; pub use crate::id::AsId; pub use crate::id::FromId; @@ -85,6 +86,8 @@ pub mod plumbing { pub use crate::runtime::Stamp; pub use crate::runtime::StampedValue; pub use crate::salsa_struct::SalsaStructInDb; + pub use crate::storage::HasStorage; + pub use crate::storage::Storage; pub use crate::tracked_struct::TrackedStructInDb; pub use crate::update::always_update; pub use crate::update::helper::Dispatch as UpdateDispatch; @@ -94,6 +97,8 @@ pub mod plumbing { pub use crate::zalsa::IngredientCache; pub use crate::zalsa::IngredientIndex; pub use crate::zalsa::Zalsa; + pub use crate::zalsa::ZalsaDatabase; + pub use crate::zalsa_local::ZalsaLocal; pub use salsa_macro_rules::macro_if; pub use salsa_macro_rules::maybe_backdate; diff --git a/src/storage.rs b/src/storage.rs new file mode 100644 index 00000000..c9e1273b --- /dev/null +++ b/src/storage.rs @@ -0,0 +1,137 @@ +use std::{marker::PhantomData, panic::RefUnwindSafe, sync::Arc}; + +use parking_lot::{Condvar, Mutex}; + +use crate::{ + zalsa::{Zalsa, ZalsaDatabase}, + zalsa_local::{self, ZalsaLocal}, + Database, Event, EventKind, +}; + +/// Access the "storage" of a Salsa database: this is an internal plumbing trait +/// automatically implemented by `#[salsa::db]` applied to a struct. +/// +/// # Safety +/// +/// The `storage` and `storage_mut` fields must both return a reference to the same +/// storage field which must be owned by `self`. +pub unsafe trait HasStorage: Database + Sized { + fn storage(&self) -> &Storage; + fn storage_mut(&mut self) -> &mut Storage; +} + +/// Concrete implementation of the [`Database`][] trait. +/// Takes an optional type parameter `U` that allows you to thread your own data. +pub struct Storage { + /// Reference to the database. This is always `Some` except during destruction. + zalsa_impl: Option>, + + /// Coordination data for cancellation of other handles when `zalsa_mut` is called. + /// This could be stored in Zalsa but it makes things marginally cleaner to keep it separate. + coordinate: Arc, + + /// Per-thread state + zalsa_local: zalsa_local::ZalsaLocal, + + /// We store references to `Db` + phantom: PhantomData Db>, +} +struct Coordinate { + /// Counter of the number of clones of actor. Begins at 1. + /// Incremented when cloned, decremented when dropped. + clones: Mutex, + cvar: Condvar, +} + +impl Default for Storage { + fn default() -> Self { + Self { + zalsa_impl: Some(Arc::new(Zalsa::new::())), + coordinate: Arc::new(Coordinate { + clones: Mutex::new(1), + cvar: Default::default(), + }), + zalsa_local: ZalsaLocal::new(), + phantom: PhantomData, + } + } +} + +impl Storage { + /// Access the `Arc`. This should always be + /// possible as `zalsa_impl` only becomes + /// `None` once we are in the `Drop` impl. + fn zalsa_impl(&self) -> &Arc { + self.zalsa_impl.as_ref().unwrap() + } + + // ANCHOR: cancel_other_workers + /// Sets cancellation flag and blocks until all other workers with access + /// to this storage have completed. + /// + /// This could deadlock if there is a single worker with two handles to the + /// same database! + fn cancel_others(&self, db: &Db) { + let zalsa = self.zalsa_impl(); + zalsa.set_cancellation_flag(); + + db.salsa_event(&|| Event { + thread_id: std::thread::current().id(), + + kind: EventKind::DidSetCancellationFlag, + }); + + let mut clones = self.coordinate.clones.lock(); + while *clones != 1 { + self.coordinate.cvar.wait(&mut clones); + } + } + // ANCHOR_END: cancel_other_workers +} + +unsafe impl ZalsaDatabase for T { + fn zalsa(&self) -> &Zalsa { + self.storage().zalsa_impl.as_ref().unwrap() + } + + fn zalsa_mut(&mut self) -> &mut Zalsa { + self.storage().cancel_others(self); + + // The ref count on the `Arc` should now be 1 + let storage = self.storage_mut(); + let arc_zalsa_mut = storage.zalsa_impl.as_mut().unwrap(); + let zalsa_mut = Arc::get_mut(arc_zalsa_mut).unwrap(); + zalsa_mut.new_revision(); + zalsa_mut + } + + fn zalsa_local(&self) -> &ZalsaLocal { + &self.storage().zalsa_local + } +} + +impl RefUnwindSafe for Storage {} + +impl Clone for Storage { + fn clone(&self) -> Self { + *self.coordinate.clones.lock() += 1; + + Self { + zalsa_impl: self.zalsa_impl.clone(), + coordinate: Arc::clone(&self.coordinate), + zalsa_local: ZalsaLocal::new(), + phantom: PhantomData, + } + } +} + +impl Drop for Storage { + fn drop(&mut self) { + // Drop the database handle *first* + self.zalsa_impl.take(); + + // *Now* decrement the number of clones and notify once we have completed + *self.coordinate.clones.lock() -= 1; + self.coordinate.cvar.notify_all(); + } +} diff --git a/src/zalsa.rs b/src/zalsa.rs index f18930f3..0271b91e 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -6,13 +6,38 @@ use parking_lot::Mutex; use rustc_hash::FxHashMap; use crate::cycle::CycleRecoveryStrategy; -use crate::database::UserData; use crate::ingredient::{Ingredient, Jar}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::{Runtime, WaitResult}; use crate::views::Views; use crate::zalsa_local::ZalsaLocal; -use crate::{Database, DatabaseImpl, DatabaseKeyIndex, Durability, Revision}; +use crate::{Database, DatabaseKeyIndex, Durability, Revision}; + +/// Internal plumbing trait; implemented automatically when `#[salsa::db]`(`crate::db`) is attached to your database struct. +/// Contains methods that give access to the internal data from the `storage` field. +/// +/// # Safety +/// +/// The system assumes this is implemented by a salsa procedural macro +/// which makes use of private data from the [`Storage`](`crate::storage::Storage`) struct. +/// Do not implement this yourself, instead, apply the [`salsa::db`](`crate::db`) macro +/// to your database. +pub unsafe trait ZalsaDatabase: Any { + /// Plumbing method: Access the internal salsa methods. + #[doc(hidden)] + fn zalsa(&self) -> &Zalsa; + + /// Plumbing method: Access the internal salsa methods for mutating the database. + /// + /// **WARNING:** Triggers a new revision, canceling other database handles. + /// This can lead to deadlock! + #[doc(hidden)] + fn zalsa_mut(&mut self) -> &mut Zalsa; + + /// Access the thread-local state associated with this database + #[doc(hidden)] + fn zalsa_local(&self) -> &ZalsaLocal; +} pub fn views(db: &Db) -> &Views { db.zalsa().views() @@ -61,8 +86,6 @@ impl IngredientIndex { /// /// **NOT SEMVER STABLE.** pub struct Zalsa { - user_data: Box, - views_of: Views, nonce: Nonce, @@ -89,15 +112,14 @@ pub struct Zalsa { } impl Zalsa { - pub(crate) fn with(user_data: U) -> Self { + pub(crate) fn new() -> Self { Self { - views_of: Views::new::>(), + views_of: Views::new::(), nonce: NONCE.nonce(), jar_map: Default::default(), ingredients_vec: Default::default(), ingredients_requiring_reset: Default::default(), runtime: Runtime::default(), - user_data: Box::new(user_data), } } @@ -175,14 +197,6 @@ impl Zalsa { self.runtime.set_cancellation_flag() } - pub(crate) fn user_data(&self) -> &(dyn Any + Send + Sync) { - &*self.user_data - } - - pub(crate) fn user_data_mut(&mut self) -> &mut (dyn Any + Send + Sync) { - &mut *self.user_data - } - /// Triggers a new revision. Invoked automatically when you call `zalsa_mut` /// and so doesn't need to be called otherwise. pub(crate) fn new_revision(&mut self) -> Revision { diff --git a/tests/accumulate-reuse-workaround.rs b/tests/accumulate-reuse-workaround.rs index f43c098e..d72f971b 100644 --- a/tests/accumulate-reuse-workaround.rs +++ b/tests/accumulate-reuse-workaround.rs @@ -3,10 +3,10 @@ //! reuse. mod common; -use common::{LogDatabase, Logger}; +use common::{LogDatabase, LoggerDatabase}; use expect_test::expect; -use salsa::{Accumulator, DatabaseImpl, Setter}; +use salsa::{Accumulator, Setter}; use test_log::test; #[salsa::input] @@ -49,7 +49,7 @@ fn accumulated(db: &dyn LogDatabase, input: List) -> Vec { #[test] fn test1() { - let mut db: DatabaseImpl = DatabaseImpl::default(); + let mut db = LoggerDatabase::default(); let l1 = List::new(&db, 1, None); let l2 = List::new(&db, 2, Some(l1)); diff --git a/tests/accumulate-reuse.rs b/tests/accumulate-reuse.rs index e9ac47b3..b9962849 100644 --- a/tests/accumulate-reuse.rs +++ b/tests/accumulate-reuse.rs @@ -4,10 +4,10 @@ //! are the accumulated values from another query. mod common; -use common::{LogDatabase, Logger}; +use common::{LogDatabase, LoggerDatabase}; use expect_test::expect; -use salsa::{prelude::*, DatabaseImpl}; +use salsa::{Accumulator, Setter}; use test_log::test; #[salsa::input] @@ -40,7 +40,7 @@ fn compute(db: &dyn LogDatabase, input: List) -> u32 { #[test] fn test1() { - let mut db = DatabaseImpl::with(Logger::default()); + let mut db = LoggerDatabase::default(); let l1 = List::new(&db, 1, None); let l2 = List::new(&db, 2, Some(l1)); diff --git a/tests/accumulate.rs b/tests/accumulate.rs index ea16666d..a69c27a1 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -1,5 +1,5 @@ mod common; -use common::{LogDatabase, Logger}; +use common::{LogDatabase, LoggerDatabase}; use expect_test::expect; use salsa::{Accumulator, Setter}; @@ -55,7 +55,7 @@ fn push_b_logs(db: &dyn LogDatabase, input: MyInput) { #[test] fn accumulate_once() { - let db = salsa::DatabaseImpl::with(Logger::default()); + let db = common::LoggerDatabase::default(); // Just call accumulate on a base input to see what happens. let input = MyInput::new(&db, 2, 3); @@ -91,7 +91,7 @@ fn accumulate_once() { #[test] fn change_a_from_2_to_0() { - let mut db = salsa::DatabaseImpl::with(Logger::default()); + let mut db = common::LoggerDatabase::default(); // Accumulate logs for `a = 2` and `b = 3` let input = MyInput::new(&db, 2, 3); @@ -146,7 +146,7 @@ fn change_a_from_2_to_0() { #[test] fn change_a_from_2_to_1() { - let mut db = salsa::DatabaseImpl::with(Logger::default()); + let mut db = LoggerDatabase::default(); // Accumulate logs for `a = 2` and `b = 3` let input = MyInput::new(&db, 2, 3); @@ -205,7 +205,7 @@ fn change_a_from_2_to_1() { #[test] fn get_a_logs_after_changing_b() { - let mut db = salsa::DatabaseImpl::with(Logger::default()); + let mut db = common::LoggerDatabase::default(); // Invoke `push_a_logs` with `a = 2` and `b = 3` (but `b` doesn't matter) let input = MyInput::new(&db, 2, 3); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 752d9e4c..4c4e9fc7 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -2,7 +2,7 @@ #![allow(dead_code)] -use salsa::{DatabaseImpl, UserData}; +use salsa::{Database, Storage}; /// Logging userdata: provides [`LogDatabase`][] trait. /// @@ -13,10 +13,14 @@ pub struct Logger { logs: std::sync::Mutex>, } -impl UserData for Logger {} +/// Trait implemented by databases that lets them log events. +pub trait HasLogger { + /// Return a reference to the logger from the database. + fn logger(&self) -> &Logger; +} #[salsa::db] -pub trait LogDatabase: HasLogger + salsa::Database { +pub trait LogDatabase: HasLogger + Database { /// Log an event from inside a tracked function. fn push_log(&self, string: String) { self.logger().logs.lock().unwrap().push(string); @@ -40,84 +44,98 @@ pub trait LogDatabase: HasLogger + salsa::Database { } #[salsa::db] -impl LogDatabase for DatabaseImpl {} +impl LogDatabase for Db {} -/// Trait implemented by databases that lets them log events. -pub trait HasLogger { - /// Return a reference to the logger from the database. - fn logger(&self) -> &Logger; +/// Database that provides logging but does not log salsa event. +#[salsa::db] +#[derive(Default)] +pub struct LoggerDatabase { + storage: Storage, + logger: Logger, } -impl HasLogger for DatabaseImpl { +impl HasLogger for LoggerDatabase { fn logger(&self) -> &Logger { - U::logger(self) + &self.logger } } -impl HasLogger for Logger { - fn logger(&self) -> &Logger { - self - } +#[salsa::db] +impl Database for LoggerDatabase { + fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} } -/// Userdata that provides logging and logs salsa events. +/// Database that provides logging and logs salsa events. +#[salsa::db] #[derive(Default)] -pub struct EventLogger { +pub struct EventLoggerDatabase { + storage: Storage, logger: Logger, } -impl UserData for EventLogger { - fn salsa_event(db: &DatabaseImpl, event: &dyn Fn() -> salsa::Event) { - db.push_log(format!("{:?}", event())); +#[salsa::db] +impl Database for EventLoggerDatabase { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + self.push_log(format!("{:?}", event())); } } -impl HasLogger for EventLogger { +impl HasLogger for EventLoggerDatabase { fn logger(&self) -> &Logger { &self.logger } } +#[salsa::db] #[derive(Default)] -pub struct DiscardLogger(Logger); +pub struct DiscardLoggerDatabase { + storage: Storage, + logger: Logger, +} -impl UserData for DiscardLogger { - fn salsa_event(db: &DatabaseImpl, event: &dyn Fn() -> salsa::Event) { +#[salsa::db] +impl Database for DiscardLoggerDatabase { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { let event = event(); match event.kind { salsa::EventKind::WillDiscardStaleOutput { .. } | salsa::EventKind::DidDiscard { .. } => { - db.push_log(format!("salsa_event({:?})", event.kind)); + self.push_log(format!("salsa_event({:?})", event.kind)); } _ => {} } } } -impl HasLogger for DiscardLogger { +impl HasLogger for DiscardLoggerDatabase { fn logger(&self) -> &Logger { - &self.0 + &self.logger } } +#[salsa::db] #[derive(Default)] -pub struct ExecuteValidateLogger(Logger); +pub struct ExecuteValidateLoggerDatabase { + storage: Storage, + logger: Logger, +} -impl UserData for ExecuteValidateLogger { - fn salsa_event(db: &DatabaseImpl, event: &dyn Fn() -> salsa::Event) { +#[salsa::db] +impl Database for ExecuteValidateLoggerDatabase { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { let event = event(); match event.kind { salsa::EventKind::WillExecute { .. } | salsa::EventKind::DidValidateMemoizedValue { .. } => { - db.push_log(format!("salsa_event({:?})", event.kind)); + self.push_log(format!("salsa_event({:?})", event.kind)); } _ => {} } } } -impl HasLogger for ExecuteValidateLogger { +impl HasLogger for ExecuteValidateLoggerDatabase { fn logger(&self) -> &Logger { - &self.0 + &self.logger } } diff --git a/tests/deletion-cascade.rs b/tests/deletion-cascade.rs index 023e584b..23576ad7 100644 --- a/tests/deletion-cascade.rs +++ b/tests/deletion-cascade.rs @@ -3,10 +3,10 @@ //! * when we delete memoized data, also delete outputs from that data mod common; -use common::{DiscardLogger, LogDatabase}; +use common::LogDatabase; use expect_test::expect; -use salsa::{DatabaseImpl, Setter}; +use salsa::Setter; use test_log::test; #[salsa::input(singleton)] @@ -50,7 +50,7 @@ fn copy_field<'db>(db: &'db dyn LogDatabase, tracked: MyTracked<'db>) -> u32 { #[test] fn basic() { - let mut db: DatabaseImpl = Default::default(); + let mut db = common::DiscardLoggerDatabase::default(); // Creates 3 tracked structs let input = MyInput::new(&db, 3); diff --git a/tests/deletion.rs b/tests/deletion.rs index 74b69547..67d4caef 100644 --- a/tests/deletion.rs +++ b/tests/deletion.rs @@ -44,7 +44,7 @@ fn contribution_from_struct<'db>(db: &'db dyn LogDatabase, tracked: MyTracked<'d #[test] fn basic() { - let mut db: salsa::DatabaseImpl = Default::default(); + let mut db = common::DiscardLoggerDatabase::default(); // Creates 3 tracked structs let input = MyInput::new(&db, 3); diff --git a/tests/elided-lifetime-in-tracked-fn.rs b/tests/elided-lifetime-in-tracked-fn.rs index 07090d60..f62c23a8 100644 --- a/tests/elided-lifetime-in-tracked-fn.rs +++ b/tests/elided-lifetime-in-tracked-fn.rs @@ -2,10 +2,10 @@ //! compiles and executes successfully. mod common; -use common::{LogDatabase, Logger}; +use common::LogDatabase; use expect_test::expect; -use salsa::{DatabaseImpl, Setter}; +use salsa::Setter; use test_log::test; #[salsa::input] @@ -32,7 +32,7 @@ fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { #[test] fn execute() { - let mut db: DatabaseImpl = Default::default(); + let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); assert_eq!(final_result(&db, input), 22); diff --git a/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs b/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs index 5c110c21..fb62e1c5 100644 --- a/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs +++ b/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs @@ -4,10 +4,10 @@ #![allow(dead_code)] mod common; -use common::{LogDatabase, Logger}; +use common::LogDatabase; use expect_test::expect; -use salsa::{DatabaseImpl, Setter}; +use salsa::Setter; #[salsa::input] struct MyInput { @@ -43,7 +43,7 @@ fn execute() { // y = input.field / 2 // final_result_depends_on_x = x * 2 = (input.field + 1) / 2 * 2 // final_result_depends_on_y = y * 2 = input.field / 2 * 2 - let mut db: DatabaseImpl = Default::default(); + let mut db = common::LoggerDatabase::default(); // intermediate results: // x = (22 + 1) / 2 = 11 diff --git a/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs b/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs index 12f2cfbf..a1844795 100644 --- a/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs +++ b/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs @@ -4,7 +4,7 @@ #![allow(dead_code)] mod common; -use common::{LogDatabase, Logger}; +use common::LogDatabase; use expect_test::expect; use salsa::Setter; @@ -30,7 +30,7 @@ fn result_depends_on_y(db: &dyn LogDatabase, input: MyInput) -> u32 { fn execute() { // result_depends_on_x = x + 1 // result_depends_on_y = y - 1 - let mut db: salsa::DatabaseImpl = Default::default(); + let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22, 33); assert_eq!(result_depends_on_x(&db, input), 23); diff --git a/tests/hello_world.rs b/tests/hello_world.rs index 3a316d1e..04cfdee9 100644 --- a/tests/hello_world.rs +++ b/tests/hello_world.rs @@ -2,7 +2,7 @@ //! compiles and executes successfully. mod common; -use common::{LogDatabase, Logger}; +use common::LogDatabase; use expect_test::expect; use salsa::Setter; @@ -32,7 +32,7 @@ fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { #[test] fn execute() { - let mut db: salsa::DatabaseImpl = Default::default(); + let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); assert_eq!(final_result(&db, input), 22); @@ -63,7 +63,7 @@ fn execute() { /// Create and mutate a distinct input. No re-execution required. #[test] fn red_herring() { - let mut db: salsa::DatabaseImpl = Default::default(); + let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); assert_eq!(final_result(&db, input), 22); diff --git a/tests/lru.rs b/tests/lru.rs index 2d5b1b5d..0cd36002 100644 --- a/tests/lru.rs +++ b/tests/lru.rs @@ -7,7 +7,7 @@ use std::sync::{ }; mod common; -use common::{LogDatabase, Logger}; +use common::LogDatabase; use salsa::Database as _; use test_log::test; @@ -61,7 +61,7 @@ fn load_n_potatoes() -> usize { #[test] fn lru_works() { - let db: salsa::DatabaseImpl = Default::default(); + let db = common::LoggerDatabase::default(); assert_eq!(load_n_potatoes(), 0); for i in 0..128u32 { @@ -77,7 +77,7 @@ fn lru_works() { #[test] fn lru_doesnt_break_volatile_queries() { - let db: salsa::DatabaseImpl = Default::default(); + let db = common::LoggerDatabase::default(); // Create all inputs first, so that there are no revision changes among calls to `get_volatile` let inputs: Vec = (0..128usize).map(|i| MyInput::new(&db, i as u32)).collect(); @@ -95,7 +95,7 @@ fn lru_doesnt_break_volatile_queries() { #[test] fn lru_can_be_changed_at_runtime() { - let db: salsa::DatabaseImpl = Default::default(); + let db = common::LoggerDatabase::default(); assert_eq!(load_n_potatoes(), 0); let inputs: Vec<(u32, MyInput)> = (0..128).map(|i| (i, MyInput::new(&db, i))).collect(); @@ -138,7 +138,7 @@ fn lru_can_be_changed_at_runtime() { #[test] fn lru_keeps_dependency_info() { - let mut db: salsa::DatabaseImpl = Default::default(); + let mut db = common::LoggerDatabase::default(); let capacity = 32; // Invoke `get_hot_potato2` 33 times. This will (in turn) invoke diff --git a/tests/parallel/parallel_cancellation.rs b/tests/parallel/parallel_cancellation.rs index 55f81e8d..a106ec7d 100644 --- a/tests/parallel/parallel_cancellation.rs +++ b/tests/parallel/parallel_cancellation.rs @@ -3,7 +3,6 @@ //! both intra and cross thread. use salsa::Cancelled; -use salsa::DatabaseImpl; use salsa::Setter; use crate::setup::Knobs; @@ -43,8 +42,7 @@ fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> MyInput { #[test] fn execute() { - let mut db = >::default(); - db.knobs().signal_on_will_block.store(3); + let mut db = Knobs::default(); let input = MyInput::new(&db, 1); @@ -53,6 +51,7 @@ fn execute() { move || a1(&db, input) }); + db.signal_on_did_cancel.store(2); input.set_field(&mut db).to(2); // Assert thread A *should* was cancelled diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs index ac20b504..9dc8c74e 100644 --- a/tests/parallel/parallel_cycle_all_recover.rs +++ b/tests/parallel/parallel_cycle_all_recover.rs @@ -2,8 +2,6 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. -use salsa::DatabaseImpl; - use crate::setup::Knobs; use crate::setup::KnobsDatabase; @@ -86,13 +84,13 @@ fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i3 #[test] fn execute() { - let db = >::default(); - db.knobs().signal_on_will_block.store(3); + let db = Knobs::default(); let input = MyInput::new(&db, 1); let thread_a = std::thread::spawn({ let db = db.clone(); + db.knobs().signal_on_will_block.store(3); move || a1(&db, input) }); diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs index 8bca2f61..593d46a6 100644 --- a/tests/parallel/parallel_cycle_mid_recover.rs +++ b/tests/parallel/parallel_cycle_mid_recover.rs @@ -2,8 +2,6 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. -use salsa::DatabaseImpl; - use crate::setup::{Knobs, KnobsDatabase}; #[salsa::input] @@ -81,8 +79,7 @@ fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i3 #[test] fn execute() { - let db = >::default(); - db.knobs().signal_on_will_block.store(3); + let db = Knobs::default(); let input = MyInput::new(&db, 1); @@ -93,6 +90,7 @@ fn execute() { let thread_b = std::thread::spawn({ let db = db.clone(); + db.knobs().signal_on_will_block.store(3); move || b1(&db, input) }); diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs index d74aa5b0..bcd0ea58 100644 --- a/tests/parallel/parallel_cycle_none_recover.rs +++ b/tests/parallel/parallel_cycle_none_recover.rs @@ -6,7 +6,6 @@ use crate::setup::Knobs; use crate::setup::KnobsDatabase; use expect_test::expect; use salsa::Database; -use salsa::DatabaseImpl; #[salsa::input] pub(crate) struct MyInput { @@ -37,13 +36,13 @@ pub(crate) fn b(db: &dyn KnobsDatabase, input: MyInput) -> i32 { #[test] fn execute() { - let db = >::default(); - db.knobs().signal_on_will_block.store(3); + let db = Knobs::default(); let input = MyInput::new(&db, -1); let thread_a = std::thread::spawn({ let db = db.clone(); + db.knobs().signal_on_will_block.store(3); move || a(&db, input) }); diff --git a/tests/parallel/parallel_cycle_one_recover.rs b/tests/parallel/parallel_cycle_one_recover.rs index 2bf53857..c0378282 100644 --- a/tests/parallel/parallel_cycle_one_recover.rs +++ b/tests/parallel/parallel_cycle_one_recover.rs @@ -2,8 +2,6 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. -use salsa::DatabaseImpl; - use crate::setup::{Knobs, KnobsDatabase}; #[salsa::input] @@ -70,13 +68,13 @@ pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { #[test] fn execute() { - let db = >::default(); - db.knobs().signal_on_will_block.store(3); + let db = Knobs::default(); let input = MyInput::new(&db, 1); let thread_a = std::thread::spawn({ let db = db.clone(); + db.knobs().signal_on_will_block.store(3); move || a1(&db, input) }); diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 6410f853..56d204ee 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -1,5 +1,7 @@ +use std::sync::Arc; + use crossbeam::atomic::AtomicCell; -use salsa::{Database, DatabaseImpl, UserData}; +use salsa::Database; use crate::signal::Signal; @@ -14,14 +16,17 @@ pub(crate) trait KnobsDatabase: Database { fn wait_for(&self, stage: usize); } -/// Various "knobs" that can be used to customize how the queries +/// A database containing various "knobs" that can be used to customize how the queries /// behave on one specific thread. Note that this state is /// intentionally thread-local (apart from `signal`). +#[salsa::db] #[derive(Default)] pub(crate) struct Knobs { + storage: salsa::Storage, + /// A kind of flexible barrier used to coordinate execution across /// threads to ensure we reach various weird states. - pub(crate) signal: Signal, + pub(crate) signal: Arc, /// When this database is about to block, send this signal. pub(crate) signal_on_will_block: AtomicCell, @@ -30,15 +35,31 @@ pub(crate) struct Knobs { pub(crate) signal_on_did_cancel: AtomicCell, } -impl UserData for Knobs { - fn salsa_event(db: &DatabaseImpl, event: &dyn Fn() -> salsa::Event) { +impl Clone for Knobs { + #[track_caller] + fn clone(&self) -> Self { + // To avoid mistakes, check that when we clone, we haven't customized this behavior yet + assert_eq!(self.signal_on_will_block.load(), 0); + assert_eq!(self.signal_on_did_cancel.load(), 0); + Self { + storage: self.storage.clone(), + signal: self.signal.clone(), + signal_on_will_block: AtomicCell::new(0), + signal_on_did_cancel: AtomicCell::new(0), + } + } +} + +#[salsa::db] +impl salsa::Database for Knobs { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { let event = event(); match event.kind { salsa::EventKind::WillBlockOn { .. } => { - db.signal(db.signal_on_will_block.load()); + self.signal(self.signal_on_will_block.load()); } salsa::EventKind::DidSetCancellationFlag => { - db.signal(db.signal_on_did_cancel.load()); + self.signal(self.signal_on_did_cancel.load()); } _ => {} } @@ -46,7 +67,7 @@ impl UserData for Knobs { } #[salsa::db] -impl KnobsDatabase for DatabaseImpl { +impl KnobsDatabase for Knobs { fn knobs(&self) -> &Knobs { self } diff --git a/tests/preverify-struct-with-leaked-data.rs b/tests/preverify-struct-with-leaked-data.rs index 99391709..d0946dc4 100644 --- a/tests/preverify-struct-with-leaked-data.rs +++ b/tests/preverify-struct-with-leaked-data.rs @@ -3,9 +3,9 @@ use std::cell::Cell; +use common::LogDatabase; use expect_test::expect; mod common; -use common::{EventLogger, LogDatabase}; use salsa::{Database, Setter}; use test_log::test; @@ -44,7 +44,7 @@ fn function(db: &dyn Database, input: MyInput) -> usize { #[test] fn test_leaked_inputs_ignored() { - let mut db: salsa::DatabaseImpl = Default::default(); + let mut db = common::EventLoggerDatabase::default(); let input = MyInput::new(&db, 10, 20); let result_in_rev_1 = function(&db, input); diff --git a/tests/synthetic_write.rs b/tests/synthetic_write.rs index b7629280..cf036bdc 100644 --- a/tests/synthetic_write.rs +++ b/tests/synthetic_write.rs @@ -4,7 +4,7 @@ mod common; -use common::{ExecuteValidateLogger, LogDatabase, Logger}; +use common::{LogDatabase, Logger}; use expect_test::expect; use salsa::{Database, DatabaseImpl, Durability, Event, EventKind}; @@ -20,7 +20,7 @@ fn tracked_fn(db: &dyn Database, input: MyInput) -> u32 { #[test] fn execute() { - let mut db: DatabaseImpl = Default::default(); + let mut db = common::ExecuteValidateLoggerDatabase::default(); let input = MyInput::new(&db, 22); assert_eq!(tracked_fn(&db, input), 44); diff --git a/tests/tracked-struct-value-field-bad-eq.rs b/tests/tracked-struct-value-field-bad-eq.rs index ec4cac5f..dec73366 100644 --- a/tests/tracked-struct-value-field-bad-eq.rs +++ b/tests/tracked-struct-value-field-bad-eq.rs @@ -54,7 +54,7 @@ fn read_tracked_struct<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> b #[test] fn execute() { - let mut db: salsa::DatabaseImpl = Default::default(); + let mut db = common::ExecuteValidateLoggerDatabase::default(); let input = MyInput::new(&db, true); let result = the_fn(&db, input); diff --git a/tests/tracked_fn_no_eq.rs b/tests/tracked_fn_no_eq.rs index ee7c5651..6f223b79 100644 --- a/tests/tracked_fn_no_eq.rs +++ b/tests/tracked_fn_no_eq.rs @@ -1,8 +1,8 @@ mod common; -use common::{LogDatabase, Logger}; +use common::LogDatabase; use expect_test::expect; -use salsa::{DatabaseImpl, Setter as _}; +use salsa::Setter as _; #[salsa::input] struct Input { @@ -26,7 +26,7 @@ fn derived(db: &dyn LogDatabase, input: Input) -> u32 { } #[test] fn invoke() { - let mut db: DatabaseImpl = Default::default(); + let mut db = common::LoggerDatabase::default(); let input = Input::new(&db, 5); let x = derived(&db, input); diff --git a/tests/tracked_fn_read_own_entity.rs b/tests/tracked_fn_read_own_entity.rs index ad4a7002..48ed793d 100644 --- a/tests/tracked_fn_read_own_entity.rs +++ b/tests/tracked_fn_read_own_entity.rs @@ -3,7 +3,7 @@ use expect_test::expect; mod common; -use common::{LogDatabase, Logger}; +use common::LogDatabase; use salsa::Setter; use test_log::test; @@ -33,7 +33,7 @@ fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { #[test] fn one_entity() { - let mut db: salsa::DatabaseImpl = Default::default(); + let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); assert_eq!(final_result(&db, input), 22); @@ -64,7 +64,7 @@ fn one_entity() { /// Create and mutate a distinct input. No re-execution required. #[test] fn red_herring() { - let mut db: salsa::DatabaseImpl = Default::default(); + let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); assert_eq!(final_result(&db, input), 22); diff --git a/tests/tracked_fn_read_own_specify.rs b/tests/tracked_fn_read_own_specify.rs index 0c643366..426d18a7 100644 --- a/tests/tracked_fn_read_own_specify.rs +++ b/tests/tracked_fn_read_own_specify.rs @@ -1,6 +1,6 @@ use expect_test::expect; mod common; -use common::{LogDatabase, Logger}; +use common::LogDatabase; use salsa::Database; #[salsa::input] @@ -29,7 +29,7 @@ fn tracked_fn_extra<'db>(db: &dyn LogDatabase, input: MyTracked<'db>) -> u32 { #[test] fn execute() { - let mut db: salsa::DatabaseImpl = salsa::DatabaseImpl::default(); + let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); assert_eq!(tracked_fn(&db, input), 2222); db.assert_logs(expect![[r#" From cafbe9247b276b9b7db13a225731d4953c5b4de1 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 4 Aug 2024 02:22:27 -0400 Subject: [PATCH 23/29] update debug output --- examples/calc/parser.rs | 48 +++++++++++------------------------------ 1 file changed, 12 insertions(+), 36 deletions(-) diff --git a/examples/calc/parser.rs b/examples/calc/parser.rs index 40077361..5c9b6675 100644 --- a/examples/calc/parser.rs +++ b/examples/calc/parser.rs @@ -401,9 +401,7 @@ fn parse_print() { end: 7, }, data: Number( - OrderedFloat( - 1.0, - ), + 1.0, ), }, Add, @@ -414,9 +412,7 @@ fn parse_print() { end: 11, }, data: Number( - OrderedFloat( - 2.0, - ), + 2.0, ), }, ), @@ -552,9 +548,7 @@ fn parse_example() { end: 81, }, data: Number( - OrderedFloat( - 3.14, - ), + 3.14, ), }, Multiply, @@ -615,9 +609,7 @@ fn parse_example() { end: 124, }, data: Number( - OrderedFloat( - 3.0, - ), + 3.0, ), }, Expression { @@ -627,9 +619,7 @@ fn parse_example() { end: 127, }, data: Number( - OrderedFloat( - 4.0, - ), + 4.0, ), }, ], @@ -662,9 +652,7 @@ fn parse_example() { end: 160, }, data: Number( - OrderedFloat( - 1.0, - ), + 1.0, ), }, ], @@ -693,9 +681,7 @@ fn parse_example() { end: 182, }, data: Number( - OrderedFloat( - 11.0, - ), + 11.0, ), }, Multiply, @@ -706,9 +692,7 @@ fn parse_example() { end: 186, }, data: Number( - OrderedFloat( - 2.0, - ), + 2.0, ), }, ), @@ -782,9 +766,7 @@ fn parse_precedence() { end: 7, }, data: Number( - OrderedFloat( - 1.0, - ), + 1.0, ), }, Add, @@ -802,9 +784,7 @@ fn parse_precedence() { end: 11, }, data: Number( - OrderedFloat( - 2.0, - ), + 2.0, ), }, Multiply, @@ -815,9 +795,7 @@ fn parse_precedence() { end: 15, }, data: Number( - OrderedFloat( - 3.0, - ), + 3.0, ), }, ), @@ -832,9 +810,7 @@ fn parse_precedence() { end: 19, }, data: Number( - OrderedFloat( - 4.0, - ), + 4.0, ), }, ), From 9f95b37af942f90302c086a164bc1c420c77390e Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 10:52:39 +0000 Subject: [PATCH 24/29] add a justfile for convenience This should really be synchronized with the codespaces and github configuration but... I'm not clever enough to do all that. --- justfile | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 justfile diff --git a/justfile b/justfile new file mode 100644 index 00000000..e80aa4c1 --- /dev/null +++ b/justfile @@ -0,0 +1,7 @@ +test: + cargo test --workspace --all-features --all-targets + +miri: + cargo +nightly miri test --no-fail-fast --all-features + +all: test miri \ No newline at end of file From bca9180e058a74569018680d85a2af263fe57886 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 11:29:26 +0000 Subject: [PATCH 25/29] just cache the index --- src/zalsa.rs | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/zalsa.rs b/src/zalsa.rs index 0271b91e..c91a02d0 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -1,4 +1,5 @@ use std::any::{Any, TypeId}; +use std::marker::PhantomData; use std::thread::ThreadId; use orx_concurrent_vec::ConcurrentVec; @@ -242,7 +243,8 @@ pub struct IngredientCache where I: Ingredient, { - cached_data: std::sync::OnceLock<(Nonce, *const I)>, + cached_data: std::sync::OnceLock<(Nonce, IngredientIndex)>, + phantom: PhantomData I>, } unsafe impl Sync for IngredientCache where I: Ingredient + Sync {} @@ -264,6 +266,7 @@ where pub const fn new() -> Self { Self { cached_data: std::sync::OnceLock::new(), + phantom: PhantomData, } } @@ -274,24 +277,24 @@ where db: &'s dyn Database, create_index: impl Fn() -> IngredientIndex, ) -> &'s I { - let &(nonce, ingredient) = self.cached_data.get_or_init(|| { - let ingredient = self.create_ingredient(db, &create_index); - (db.zalsa().nonce(), ingredient as *const I) + let zalsa = db.zalsa(); + let (nonce, index) = self.cached_data.get_or_init(|| { + let index = create_index(); + (zalsa.nonce(), index) }); - if db.zalsa().nonce() == nonce { - unsafe { &*ingredient } + // FIXME: We used to cache a raw pointer to the revision but miri + // was reporting errors because that pointer was derived from an `&` + // that is invalidated when the next revision starts with an `&mut`. + // + // We could fix it with orxfun/orx-concurrent-vec#18 or by "refreshing" the cache + // when the revision changes but just caching the index is an awful lot simpler. + + if db.zalsa().nonce() == *nonce { + zalsa.lookup_ingredient(*index).assert_type::() } else { - self.create_ingredient(db, &create_index) + let index = create_index(); + zalsa.lookup_ingredient(index).assert_type::() } } - - fn create_ingredient<'s>( - &self, - storage: &'s dyn Database, - create_index: &impl Fn() -> IngredientIndex, - ) -> &'s I { - let index = create_index(); - storage.zalsa().lookup_ingredient(index).assert_type::() - } } From 83be1e4877d7a51728a34dcb678aeff29266bb3a Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 28 Jul 2024 12:34:04 +0000 Subject: [PATCH 26/29] make the Views type miri-safe and add more comments --- src/views.rs | 115 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 6 deletions(-) diff --git a/src/views.rs b/src/views.rs index 19798737..5fd08c2f 100644 --- a/src/views.rs +++ b/src/views.rs @@ -7,18 +7,82 @@ use orx_concurrent_vec::ConcurrentVec; use crate::Database; +/// A `Views` struct is associated with some specific database type +/// (a `DatabaseImpl` for some existential `U`). It contains functions +/// to downcast from that type to `dyn DbView` for various traits `DbView`. +/// None of these types are known at compilation time, they are all checked +/// dynamically through `TypeId` magic. +/// +/// You can think of the struct as looking like: +/// +/// ```rust,ignore +/// struct Views { +/// source_type_id: TypeId, // `TypeId` for `Db` +/// view_casters: Arc { +/// ViewCaster +/// }>>, +/// } +/// ``` #[derive(Clone)] pub struct Views { source_type_id: TypeId, view_casters: Arc>, } +/// A ViewCaster contains a trait object that can cast from the +/// (ghost) `Db` type of `Views` to some (ghost) `DbView` type. +/// +/// You can think of the struct as looking like: +/// +/// ```rust,ignore +/// struct ViewCaster { +/// target_type_id: TypeId, // TypeId of DbView +/// type_name: &'static str, // type name of DbView +/// cast_to: OpaqueBoxDyn, // a `Box>` that expects a `Db` +/// free_box: Box, // the same box as above, but upcast to `dyn Free` +/// } +/// ``` +/// +/// As you can see, we have to work very hard to manage things +/// in a way that miri is happy with. What is going on here? +/// +/// * The `cast_to` is the cast object, but we can't actually name its type, so +/// we transmute it into some opaque bytes. We can transmute it back once we +/// are in a function monormophized over some function `T` that has the same type-id +/// as `target_type_id`. +/// * The problem is that dropping `cast_to` has no effect and we need +/// to free the box! To do that, we *also* upcast the box to a `Box`. +/// This trait has no purpose but to carry a destructor. struct ViewCaster { + /// The id of the target type `DbView` that we can cast to. target_type_id: TypeId, + + /// The name of the target type `DbView` that we can cast to. type_name: &'static str, - func: fn(&Dummy) -> &Dummy, + + /// A "type-obscured" `Box>`, where `DbView` + /// is the type whose id is encoded in `target_type_id`. + cast_to: OpaqueBoxDyn, + + /// An upcasted version of `cast_to`; the only purpose of this field is + /// to be dropped in the destructor, see `ViewCaster` comment. + #[allow(dead_code)] + free_box: Box, } +type OpaqueBoxDyn = [u8; std::mem::size_of::>>()]; + +trait CastTo: Free { + /// # Safety requirement + /// + /// `db` must have a data pointer whose type is the `Db` type for `Self` + unsafe fn cast<'db>(&self, db: &'db dyn Database) -> &'db DbView; + + fn into_box_free(self: Box) -> Box; +} + +trait Free: Send + Sync {} + #[allow(dead_code)] enum Dummy {} @@ -45,10 +109,21 @@ impl Views { return; } + let cast_to: Box> = Box::new(func); + let cast_to: OpaqueBoxDyn = + unsafe { std::mem::transmute::>, OpaqueBoxDyn>(cast_to) }; + + // Create a second copy of `cast_to` (which is now `Copy`) and upcast it to a `Box`. + // We will drop this box to run the destructor. + let free_box: Box = unsafe { + std::mem::transmute::>>(cast_to).into_box_free() + }; + self.view_casters.push(ViewCaster { target_type_id, type_name: std::any::type_name::(), - func: unsafe { std::mem::transmute:: &DbView, fn(&Dummy) -> &Dummy>(func) }, + cast_to, + free_box, }); } @@ -73,8 +148,12 @@ impl Views { // While the compiler doesn't know what `X` is at this point, we know it's the // same as the true type of `db_data_ptr`, and the memory representation for `()` // and `&X` are the same (since `X` is `Sized`). - let func: fn(&()) -> &DbView = unsafe { std::mem::transmute(caster.func) }; - return Some(func(data_ptr(db))); + let cast_to: &OpaqueBoxDyn = &caster.cast_to; + unsafe { + let cast_to = + std::mem::transmute::<&OpaqueBoxDyn, &Box>>(cast_to); + return Some(cast_to.cast(db)); + }; } } @@ -98,8 +177,32 @@ impl std::fmt::Debug for ViewCaster { /// Given a wide pointer `T`, extracts the data pointer (typed as `()`). /// This is safe because `()` gives no access to any data and has no validity requirements in particular. -fn data_ptr(t: &T) -> &() { +unsafe fn data_ptr(t: &T) -> &U { let t: *const T = t; - let u: *const () = t as *const (); + let u: *const U = t as *const U; unsafe { &*u } } + +impl CastTo for fn(&Db) -> &DbView +where + Db: Database, + DbView: ?Sized + Any, +{ + unsafe fn cast<'db>(&self, db: &'db dyn Database) -> &'db DbView { + // This tests the safety requirement: + debug_assert_eq!(db.type_id(), TypeId::of::()); + + // SAFETY: + // + // Caller guarantees that the input is of type `Db` + // (we test it in the debug-assertion above). + let db = unsafe { data_ptr::(db) }; + (*self)(db) + } + + fn into_box_free(self: Box) -> Box { + self + } +} + +impl Free for T {} From 118e89ce2038066ec37232bb4b4e55510af40d22 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 4 Aug 2024 03:02:28 -0400 Subject: [PATCH 27/29] add `ingredient_debug_name` API --- src/database.rs | 20 ++++++++++++++++++-- src/lib.rs | 1 + 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/database.rs b/src/database.rs index 801e2cd8..5a32bd9b 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,6 +1,9 @@ -use std::any::Any; +use std::{any::Any, borrow::Cow}; -use crate::{zalsa::ZalsaDatabase, Durability, Event, Revision}; +use crate::{ + zalsa::{IngredientIndex, ZalsaDatabase}, + Durability, Event, Revision, +}; /// The trait implemented by all Salsa databases. /// You can create your own subtraits of this trait using the `#[salsa::db]`(`crate::db`) procedural macro. @@ -38,6 +41,19 @@ pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { zalsa_local.report_untracked_read(db.zalsa().current_revision()) } + /// Return the "debug name" (i.e., the struct name, etc) for an "ingredient", + /// which are the fine-grained components we use to track data. This is intended + /// for debugging and the contents of the returned string are not semver-guaranteed. + /// + /// Ingredient indices can be extracted from [`DependencyIndex`](`crate::DependencyIndex`) values. + fn ingredient_debug_name(&self, ingredient_index: IngredientIndex) -> Cow<'_, str> { + Cow::Borrowed( + self.zalsa() + .lookup_ingredient(ingredient_index) + .debug_name(), + ) + } + /// Execute `op` with the database in thread-local storage for debug print-outs. fn attach(&self, op: impl FnOnce(&Self) -> R) -> R where diff --git a/src/lib.rs b/src/lib.rs index d1b22e70..f701fc1e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,6 +44,7 @@ pub use self::revision::Revision; pub use self::runtime::Runtime; pub use self::storage::Storage; pub use self::update::Update; +pub use self::zalsa::IngredientIndex; pub use crate::attach::with_attached_database; pub use salsa_macros::accumulator; pub use salsa_macros::db; From 6ff1975e1700c80ccc4daaf512d86da095683b0d Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 4 Aug 2024 03:04:40 -0400 Subject: [PATCH 28/29] Update components/salsa-macro-rules/src/setup_input_struct.rs Co-authored-by: Micha Reiser --- components/salsa-macro-rules/src/setup_input_struct.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index d89d63d6..ab2fbc63 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -159,9 +159,9 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let (ingredient, runtime) = $Configuration::ingredient_mut(db.as_dyn_database_mut()); + let (ingredient, revision) = $Configuration::ingredient_mut(db.as_dyn_database_mut()); $zalsa::input::SetterImpl::new( - runtime, + revision, self, $field_index, ingredient, From 1bce41f5d68903ca03c7e1222977c5c8f7df24d1 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 4 Aug 2024 04:21:06 -0400 Subject: [PATCH 29/29] stop ignoring miri Fixes #520 --- .github/workflows/test.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 723f8b68..1a792ed5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -61,7 +61,6 @@ jobs: miri: name: Miri - continue-on-error: true # FIXME: https://github.com/salsa-rs/salsa/issues/520 runs-on: ubuntu-latest steps: - name: Checkout