diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index cd9a9f21e..fcc641863 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -142,12 +142,13 @@ macro_rules! setup_tracked_fn { impl $Configuration { fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { $FN_CACHE.get_or_create(db.as_dyn_database(), || { - ::zalsa_db(db); + ::zalsa_register_downcaster(db); db.zalsa().add_or_lookup_jar_by_type::<$Configuration>() }) } pub fn fn_ingredient_mut(db: &mut dyn $Db) -> &mut $zalsa::function::IngredientImpl { + ::zalsa_register_downcaster(db); let zalsa_mut = db.zalsa_mut(); let index = zalsa_mut.add_or_lookup_jar_by_type::<$Configuration>(); let (ingredient, _) = zalsa_mut.lookup_ingredient_mut(index); @@ -159,6 +160,7 @@ macro_rules! setup_tracked_fn { db: &dyn $Db, ) -> &$zalsa::interned::IngredientImpl<$Configuration> { $INTERN_CACHE.get_or_create(db.as_dyn_database(), || { + ::zalsa_register_downcaster(db); db.zalsa().add_or_lookup_jar_by_type::<$Configuration>().successor(0) }) } @@ -249,7 +251,8 @@ macro_rules! setup_tracked_fn { let fn_ingredient = <$zalsa::function::IngredientImpl<$Configuration>>::new( first_index, memo_ingredient_indices, - $lru + $lru, + zalsa.views().downcaster_for::() ); $zalsa::macro_if! { if $needs_interner { diff --git a/components/salsa-macros/src/db.rs b/components/salsa-macros/src/db.rs index e9f553bb7..6de2d3b79 100644 --- a/components/salsa-macros/src/db.rs +++ b/components/salsa-macros/src/db.rs @@ -89,10 +89,12 @@ impl DbMacro { use salsa::plumbing as #zalsa; unsafe impl #zalsa::HasStorage for #db { + #[inline(always)] fn storage(&self) -> &#zalsa::Storage { &self.#storage } + #[inline(always)] fn storage_mut(&mut self) -> &mut #zalsa::Storage { &mut self.#storage } @@ -102,16 +104,26 @@ impl DbMacro { } fn add_salsa_view_method(&self, input: &mut syn::ItemTrait) -> syn::Result<()> { + let trait_name = &input.ident; input.items.push(parse_quote! { #[doc(hidden)] - fn zalsa_db(&self); + fn zalsa_register_downcaster(&self); + }); + + let comment = format!(" Downcast a [`dyn Database`] to a [`dyn {trait_name}`]"); + input.items.push(parse_quote! { + #[doc = #comment] + /// + /// # Safety + /// + /// The input database must be of type `Self`. + #[doc(hidden)] + unsafe fn downcast(db: &dyn salsa::plumbing::Database) -> &dyn #trait_name where Self: Sized; }); Ok(()) } fn add_salsa_view_method_impl(&self, input: &mut syn::ItemImpl) -> syn::Result<()> { - let zalsa = self.hygiene.ident("zalsa"); - let Some((_, TraitPath, _)) = &input.trait_ else { return Err(syn::Error::new_spanned( &input.self_ty, @@ -121,9 +133,18 @@ impl DbMacro { input.items.push(parse_quote! { #[doc(hidden)] - fn zalsa_db(&self) { - use salsa::plumbing as #zalsa; - #zalsa::views(self).add::(|t| t); + #[inline(always)] + fn zalsa_register_downcaster(&self) { + salsa::plumbing::views(self).add(::downcast); + } + }); + input.items.push(parse_quote! { + #[doc(hidden)] + #[inline(always)] + unsafe fn downcast(db: &dyn salsa::plumbing::Database) -> &dyn #TraitPath where Self: Sized { + debug_assert_eq!(db.type_id(), ::core::any::TypeId::of::()); + // SAFETY: Same as the safety of the `downcast` method. + unsafe { &*salsa::plumbing::transmute_data_ptr::(db) } } }); Ok(()) diff --git a/components/salsa-macros/src/lib.rs b/components/salsa-macros/src/lib.rs index e86c1bfc5..2d0a8c7d5 100644 --- a/components/salsa-macros/src/lib.rs +++ b/components/salsa-macros/src/lib.rs @@ -2,8 +2,6 @@ #![recursion_limit = "256"] -extern crate proc_macro; -extern crate proc_macro2; #[macro_use] extern crate quote; diff --git a/src/accumulator.rs b/src/accumulator.rs index 2191dd198..e6f7e640d 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -100,7 +100,7 @@ impl Ingredient for IngredientImpl { self.index } - fn maybe_changed_after( + unsafe fn maybe_changed_after( &self, _db: &dyn Database, _input: Id, diff --git a/src/database.rs b/src/database.rs index a16657b0d..1891171fd 100644 --- a/src/database.rs +++ b/src/database.rs @@ -121,6 +121,7 @@ impl dyn Database { /// If the view has not been added to the database (see [`crate::views::Views`]). #[track_caller] pub fn as_view(&self) -> &DbView { - self.zalsa().views().try_view_as(self).unwrap() + let views = self.zalsa().views(); + views.downcaster_for().downcast(self) } } diff --git a/src/function.rs b/src/function.rs index ba25d96f2..b08045a8f 100644 --- a/src/function.rs +++ b/src/function.rs @@ -8,6 +8,7 @@ use crate::{ plumbing::MemoIngredientMap, salsa_struct::SalsaStructInDb, table::Table, + views::DatabaseDownCaster, zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}, zalsa_local::QueryOrigin, Cycle, Database, Id, Revision, @@ -105,6 +106,14 @@ pub struct IngredientImpl { /// Used to find memos to throw out when we have too many memoized values. lru: lru::Lru, + /// A downcaster from `dyn Database` to `C::DbView`. + /// + /// # Safety + /// + /// The supplied database must be be the same as the database used to construct the [`Views`] + /// instances that this downcaster was derived from. + view_caster: DatabaseDownCaster, + /// When `fetch` and friends executes, they return a reference to the /// value stored in the memo that is extended to live as long as the `&self` /// reference we start with. This means that whenever we remove something @@ -135,12 +144,14 @@ where index: IngredientIndex, memo_ingredient_indices: as SalsaStructInDb>::MemoIngredientMap, lru: usize, + view_caster: DatabaseDownCaster, ) -> Self { Self { index, memo_ingredient_indices, lru: lru::Lru::new(lru), deleted_entries: Default::default(), + view_caster, } } @@ -213,13 +224,14 @@ where self.index } - fn maybe_changed_after( + unsafe fn maybe_changed_after( &self, db: &dyn Database, input: Id, revision: Revision, ) -> MaybeChangedAfter { - let db = db.as_view::(); + // SAFETY: The `db` belongs to the ingredient as per caller invariant + let db = unsafe { self.view_caster.downcast_unchecked(db) }; self.maybe_changed_after(db, input, revision) } @@ -279,7 +291,7 @@ where db: &'db dyn Database, key_index: Id, ) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) { - let db = db.as_view::(); + let db = self.view_caster.downcast(db); self.accumulated_map(db, key_index) } } diff --git a/src/ingredient.rs b/src/ingredient.rs index 037e004df..dba70f644 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -51,7 +51,11 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { fn debug_name(&self) -> &'static str; /// Has the value for `input` in this ingredient changed after `revision`? - fn maybe_changed_after<'db>( + /// + /// # Safety + /// + /// The passed in database needs to be the same one that the ingredient was created with. + unsafe fn maybe_changed_after<'db>( &'db self, db: &'db dyn Database, input: Id, diff --git a/src/input.rs b/src/input.rs index b8b7473be..cf99e9702 100644 --- a/src/input.rs +++ b/src/input.rs @@ -215,7 +215,7 @@ impl Ingredient for IngredientImpl { self.ingredient_index } - fn maybe_changed_after( + unsafe fn maybe_changed_after( &self, _db: &dyn Database, _input: Id, diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 53e82a890..362d3675c 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -49,7 +49,7 @@ where CycleRecoveryStrategy::Panic } - fn maybe_changed_after( + unsafe fn maybe_changed_after( &self, db: &dyn Database, input: Id, diff --git a/src/interned.rs b/src/interned.rs index 9340323b4..86df09c5b 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -282,7 +282,7 @@ where self.ingredient_index } - fn maybe_changed_after( + unsafe fn maybe_changed_after( &self, _db: &dyn Database, _input: Id, diff --git a/src/key.rs b/src/key.rs index 2476a5e88..f3e90bb10 100644 --- a/src/key.rs +++ b/src/key.rs @@ -100,10 +100,12 @@ impl InputDependencyIndex { last_verified_at: crate::Revision, ) -> MaybeChangedAfter { match self.key_index { - Some(key_index) => db - .zalsa() - .lookup_ingredient(self.ingredient_index) - .maybe_changed_after(db, key_index, last_verified_at), + // SAFETY: The `db` belongs to the ingredient + Some(key_index) => unsafe { + db.zalsa() + .lookup_ingredient(self.ingredient_index) + .maybe_changed_after(db, key_index, last_verified_at) + }, // Data in tables themselves remain valid until the table as a whole is reset. None => MaybeChangedAfter::No(InputAccumulatedValues::Empty), } diff --git a/src/lib.rs b/src/lib.rs index bf6ff4e8f..2c9eb6a00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ #![forbid(unsafe_op_in_unsafe_fn)] +extern crate self as salsa; + mod accumulator; mod active_query; mod array; @@ -107,6 +109,7 @@ pub mod plumbing { pub use crate::update::helper::Dispatch as UpdateDispatch; pub use crate::update::helper::Fallback as UpdateFallback; pub use crate::update::Update; + pub use crate::zalsa::transmute_data_ptr; pub use crate::zalsa::views; pub use crate::zalsa::IngredientCache; pub use crate::zalsa::IngredientIndex; diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index ec5ae11a7..c88a47b04 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -735,7 +735,7 @@ where self.ingredient_index } - fn maybe_changed_after( + unsafe fn maybe_changed_after( &self, db: &dyn Database, input: Id, diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 24125e2c4..b69ffebd1 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -55,7 +55,7 @@ where crate::cycle::CycleRecoveryStrategy::Panic } - fn maybe_changed_after<'db>( + unsafe fn maybe_changed_after<'db>( &'db self, db: &'db dyn Database, input: Id, diff --git a/src/views.rs b/src/views.rs index 9a5d938ac..c9e107bb8 100644 --- a/src/views.rs +++ b/src/views.rs @@ -1,92 +1,81 @@ -use crate::{zalsa::transmute_data_ptr, Database}; -use std::{ - any::{Any, TypeId}, - sync::Arc, -}; +use std::any::{Any, TypeId}; + +use crate::Database; /// A `Views` struct is associated with some specific database type /// (a `DatabaseImpl` for some existential `U`). It contains functions -/// to downcast from that type to `dyn DbView` for various traits `DbView`. +/// to downcast from `dyn Database` to `dyn DbView` for various traits `DbView` via this specific +/// database type. /// None of these types are known at compilation time, they are all checked /// dynamically through `TypeId` magic. -/// -/// You can think of the struct as looking like: -/// -/// ```rust,ignore -/// struct Views { -/// source_type_id: TypeId, // `TypeId` for `Db` -/// view_casters: Arc { -/// ViewCaster -/// }>>, -/// } -/// ``` -#[derive(Clone)] pub struct Views { source_type_id: TypeId, - view_casters: Arc>, + view_casters: boxcar::Vec, } -/// A DynViewCaster contains a manual trait object that can cast from the -/// (ghost) `Db` type of `Views` to some (ghost) `DbView` type. -/// -/// You can think of the struct as looking like: -/// -/// ```rust,ignore -/// struct DynViewCaster { -/// target_type_id: TypeId, // TypeId of DbView -/// type_name: &'static str, // type name of DbView -/// view_caster: *mut (), // a `Box>` -/// cast: *const (), // a `unsafe fn (&ViewCaster, &dyn Database) -> &DbView` -/// drop: *const (), // the destructor for the box above -/// } -/// ``` -/// -/// The manual trait object and vtable allows for type erasure without -/// transmuting between fat pointers, whose layout is undefined. -struct DynViewCaster { - /// The id of the target type `DbView` that we can cast to. +struct ViewCaster { + /// The id of the target type `dyn DbView` that we can cast to. target_type_id: TypeId, - /// The name of the target type `DbView` that we can cast to. + /// The name of the target type `dyn DbView` that we can cast to. type_name: &'static str, - /// A pointer to a `ViewCaster`. - view_caster: *mut (), + /// Type-erased function pointer that downcasts from `dyn Database` to `dyn DbView`. + cast: ErasedDatabaseDownCasterSig, +} - /// Type-erased `ViewCaster::::vtable_cast`. - cast: *const (), +type ErasedDatabaseDownCasterSig = unsafe fn(&dyn Database) -> *const (); +type DatabaseDownCasterSig = unsafe fn(&dyn Database) -> &DbView; - /// Type-erased `ViewCaster::::drop`. - drop: unsafe fn(*mut ()), -} +pub struct DatabaseDownCaster(TypeId, DatabaseDownCasterSig); -impl Drop for DynViewCaster { - fn drop(&mut self) { - // SAFETY: We own `self.caster` and are in the destructor. - unsafe { (self.drop)(self.view_caster) }; +impl DatabaseDownCaster { + pub fn downcast<'db>(&self, db: &'db dyn Database) -> &'db DbView { + assert_eq!( + self.0, + db.type_id(), + "Database type does not match the expected type for this `Views` instance" + ); + // SAFETY: We've asserted that the database is correct. + unsafe { (self.1)(db) } } -} -// SAFETY: These traits can be implemented normally as the raw pointers -// in `DynViewCaster` are only used for type-erasure. -unsafe impl Send for DynViewCaster {} -unsafe impl Sync for DynViewCaster {} + /// Downcast `db` to `DbView`. + /// + /// # Safety + /// + /// The caller must ensure that `db` is of the correct type. + pub unsafe fn downcast_unchecked<'db>(&self, db: &'db dyn Database) -> &'db DbView { + unsafe { (self.1)(db) } + } +} impl Views { pub(crate) fn new() -> Self { let source_type_id = TypeId::of::(); + let view_casters = boxcar::Vec::new(); + // special case the no-op transformation, that way we skip out on reconstructing the wide pointer + view_casters.push(ViewCaster { + target_type_id: TypeId::of::(), + type_name: std::any::type_name::(), + // SAFETY: We are type erasing for storage, taking care of unerasing before we call + // the function pointer. + cast: unsafe { + std::mem::transmute::< + DatabaseDownCasterSig, + ErasedDatabaseDownCasterSig, + >(|db| db) + }, + }); Self { source_type_id, - view_casters: Arc::new(boxcar::Vec::new()), + view_casters, } } - /// Add a new upcast from `Db` to `T`, given the upcasting function `func`. - pub fn add(&self, func: fn(&Db) -> &DbView) { - assert_eq!(self.source_type_id, TypeId::of::(), "dyn-upcasts"); - + /// Add a new downcaster from `dyn Database` to `dyn DbView`. + pub fn add(&self, func: DatabaseDownCasterSig) { let target_type_id = TypeId::of::(); - if self .view_casters .iter() @@ -94,91 +83,42 @@ impl Views { { return; } - - let view_caster = Box::into_raw(Box::new(ViewCaster(func))); - - self.view_casters.push(DynViewCaster { + self.view_casters.push(ViewCaster { target_type_id, type_name: std::any::type_name::(), - view_caster: view_caster.cast(), - cast: ViewCaster::::erased_cast as _, - drop: ViewCaster::::erased_drop, + // SAFETY: We are type erasing for storage, taking care of unerasing before we call + // the function pointer. + cast: unsafe { + std::mem::transmute::, ErasedDatabaseDownCasterSig>( + func, + ) + }, }); } - /// Convert one handle to a salsa database (including a `dyn Database`!) to another. + /// Retrieve an downcaster function from `dyn Database` to `dyn DbView`. /// /// # Panics /// /// If the underlying type of `db` is not the same as the database type this upcasts was created for. - pub fn try_view_as<'db, DbView: ?Sized + Any>( - &self, - db: &'db dyn Database, - ) -> Option<&'db DbView> { - let db_type_id = ::type_id(db); - assert_eq!(self.source_type_id, db_type_id, "database type mismatch"); - + pub fn downcaster_for(&self) -> DatabaseDownCaster { let view_type_id = TypeId::of::(); for (_idx, view) in self.view_casters.iter() { if view.target_type_id == view_type_id { - // SAFETY: We verified that this is the view caster for the - // `DbView` type by checking type IDs above. - let view = unsafe { - let caster: unsafe fn(*const (), &dyn Database) -> &DbView = - std::mem::transmute(view.cast); - caster(view.view_caster, db) - }; - - return Some(view); + // SAFETY: We are unerasing the type erased function pointer having made sure the + // TypeId matches. + return DatabaseDownCaster(self.source_type_id, unsafe { + std::mem::transmute::>( + view.cast, + ) + }); } } - None - } -} - -/// A generic downcaster for specific `Db` and `DbView` types. -struct ViewCaster(fn(&Db) -> &DbView); - -impl ViewCaster -where - Db: Database, - DbView: ?Sized + Any, -{ - /// Obtain a reference of type `DbView` from a database. - /// - /// # Safety - /// - /// The input database must be of type `Db`. - unsafe fn cast<'db>(&self, db: &'db dyn Database) -> &'db DbView { - // This tests the safety requirement: - debug_assert_eq!(db.type_id(), TypeId::of::()); - - // SAFETY: - // - // Caller guarantees that the input is of type `Db` - // (we test it in the debug-assertion above). - let db = unsafe { transmute_data_ptr::(db) }; - (self.0)(db) - } - - /// A type-erased version of `ViewCaster::::cast`. - /// - /// # Safety - /// - /// The underlying type of `caster` must be `ViewCaster::`. - unsafe fn erased_cast(caster: *mut (), db: &dyn Database) -> &DbView { - let caster = unsafe { &*caster.cast::>() }; - unsafe { caster.cast(db) } - } - - /// The destructor for `Box>`. - /// - /// # Safety - /// - /// All the safety requirements of `Box::>::from_raw` apply. - unsafe fn erased_drop(caster: *mut ()) { - let _: Box> = unsafe { Box::from_raw(caster.cast()) }; + panic!( + "No downcaster registered for type `{}` in `Views`", + std::any::type_name::(), + ); } } @@ -190,7 +130,7 @@ impl std::fmt::Debug for Views { } } -impl std::fmt::Debug for DynViewCaster { +impl std::fmt::Debug for ViewCaster { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("DynViewCaster") .field(&self.type_name) diff --git a/src/zalsa.rs b/src/zalsa.rs index 2069ce52c..a0286b428 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -179,10 +179,6 @@ impl Zalsa { } } - pub(crate) fn views(&self) -> &Views { - &self.views_of - } - pub(crate) fn nonce(&self) -> Nonce { self.nonce } @@ -265,6 +261,11 @@ impl Zalsa { /// Semver unstable APIs used by the macro expansions impl Zalsa { + /// **NOT SEMVER STABLE** + pub fn views(&self) -> &Views { + &self.views_of + } + /// **NOT SEMVER STABLE** #[inline] pub fn lookup_page_type_id(&self, id: Id) -> TypeId { @@ -471,10 +472,10 @@ where /// Given a wide pointer `T`, extracts the data pointer (typed as `U`). /// -/// # Safety requirement +/// # Safety /// /// `U` must be correct type for the data pointer. -pub(crate) unsafe fn transmute_data_ptr(t: &T) -> &U { +pub unsafe fn transmute_data_ptr(t: &T) -> &U { let t: *const T = t; let u: *const U = t as *const U; unsafe { &*u } @@ -482,7 +483,7 @@ pub(crate) unsafe fn transmute_data_ptr(t: &T) -> &U { /// Given a wide pointer `T`, extracts the data pointer (typed as `U`). /// -/// # Safety requirement +/// # Safety /// /// `U` must be correct type for the data pointer. pub(crate) unsafe fn transmute_data_mut_ptr(t: &mut T) -> &mut U {