diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index efc27398f..b35e32abe 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -254,21 +254,45 @@ macro_rules! setup_tracked_fn { struct_index } }; - let memo_ingredient_indices = From::from((zalsa, struct_index, first_index)); - let fn_ingredient = <$zalsa::function::IngredientImpl<$Configuration>>::new( - first_index, - memo_ingredient_indices, - $lru, - zalsa.views().downcaster_for::() - ); + $zalsa::macro_if! { $needs_interner => + let intern_ingredient = <$zalsa::interned::IngredientImpl<$Configuration>>::new( + first_index.successor(0) + ); + } + + let intern_ingredient_memo_types = $zalsa::macro_if! { + if $needs_interner { + Some($zalsa::Ingredient::memo_table_types(&intern_ingredient)) + } else { + None + } + }; + // SAFETY: We call with the correct memo types. + let memo_ingredient_indices = unsafe { + $zalsa::NewMemoIngredientIndices::create( + zalsa, + struct_index, + first_index, + $zalsa::function::MemoEntryType::of::<$zalsa::function::Memo<$Configuration>>(), + intern_ingredient_memo_types, + ) + }; + + // SAFETY: We pass the MemoEntryType for this Configuration, and we lookup the memo types table correctly. + let fn_ingredient = unsafe { + <$zalsa::function::IngredientImpl<$Configuration>>::new( + first_index, + memo_ingredient_indices, + $lru, + zalsa.views().downcaster_for::(), + ) + }; $zalsa::macro_if! { if $needs_interner { vec![ Box::new(fn_ingredient), - Box::new(<$zalsa::interned::IngredientImpl<$Configuration>>::new( - first_index.successor(0) - )), + Box::new(intern_ingredient), ] } else { vec![ diff --git a/src/accumulator.rs b/src/accumulator.rs index c01e487de..819767bd3 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -4,12 +4,14 @@ use std::any::{Any, TypeId}; use std::fmt; use std::marker::PhantomData; use std::panic::UnwindSafe; +use std::sync::Arc; use accumulated::{Accumulated, AnyAccumulated}; use crate::function::VerifyResult; use crate::ingredient::{fmt_index, Ingredient, Jar}; use crate::plumbing::IngredientIndices; +use crate::table::memo::MemoTableTypes; use crate::zalsa::{IngredientIndex, Zalsa}; use crate::{Database, Id, Revision}; @@ -110,6 +112,10 @@ impl Ingredient for IngredientImpl { fn debug_name(&self) -> &'static str { A::DEBUG_NAME } + + fn memo_table_types(&self) -> Arc { + unreachable!("accumulator does not allocate pages") + } } impl std::fmt::Debug for IngredientImpl diff --git a/src/function.rs b/src/function.rs index d483323c2..49d6e7f5d 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,6 +1,7 @@ use std::any::Any; use std::fmt; use std::ptr::NonNull; +use std::sync::Arc; pub(crate) use maybe_changed_after::VerifyResult; @@ -11,6 +12,7 @@ use crate::ingredient::{fmt_index, Ingredient}; use crate::key::DatabaseKeyIndex; use crate::plumbing::MemoIngredientMap; use crate::salsa_struct::SalsaStructInDb; +use crate::table::memo::MemoTableTypes; use crate::table::sync::ClaimResult; use crate::table::Table; use crate::views::DatabaseDownCaster; @@ -30,6 +32,8 @@ mod maybe_changed_after; mod memo; mod specify; +pub type Memo = memo::Memo<::Output<'static>>; + pub trait Configuration: Any { const DEBUG_NAME: &'static str; @@ -142,7 +146,10 @@ impl IngredientImpl where C: Configuration, { - pub fn new( + /// # Safety + /// + /// `memo_type` and `memo_table_types` must be correct. + pub unsafe fn new( index: IngredientIndex, memo_ingredient_indices: as SalsaStructInDb>::MemoIngredientMap, lru: usize, @@ -314,6 +321,10 @@ where C::DEBUG_NAME } + fn memo_table_types(&self) -> Arc { + unreachable!("function does not allocate pages") + } + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { C::CYCLE_STRATEGY } diff --git a/src/function/memo.rs b/src/function/memo.rs index eaa315cb2..5fe0a5432 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -9,7 +9,7 @@ use crate::cycle::{CycleHeads, CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}; use crate::function::{Configuration, IngredientImpl}; use crate::key::DatabaseKeyIndex; use crate::revision::AtomicRevision; -use crate::table::memo::MemoTable; +use crate::table::memo::MemoTableWithTypesMut; use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{QueryOrigin, QueryRevisions}; use crate::{Event, EventKind, Id, Revision}; @@ -84,7 +84,7 @@ impl IngredientImpl { /// with an equivalent memo that has no value. If the memo is untracked, FixpointInitial, /// or has values assigned as output of another query, this has no effect. pub(super) fn evict_value_from_memo_for( - table: &mut MemoTable, + table: MemoTableWithTypesMut<'_>, memo_ingredient_index: MemoIngredientIndex, ) { let map = |memo: &mut Memo>| { @@ -120,7 +120,7 @@ impl IngredientImpl { } #[derive(Debug)] -pub(super) struct Memo { +pub struct Memo { /// The result of the query, if we decide to memoize it. pub(super) value: Option, diff --git a/src/ingredient.rs b/src/ingredient.rs index a62e660de..418735e26 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -1,10 +1,12 @@ use std::any::{Any, TypeId}; use std::fmt; +use std::sync::Arc; use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}; use crate::cycle::CycleRecoveryStrategy; use crate::function::VerifyResult; use crate::plumbing::IngredientIndices; +use crate::table::memo::MemoTableTypes; use crate::table::Table; use crate::zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, Zalsa}; use crate::zalsa_local::QueryOrigin; @@ -132,6 +134,8 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { ); } + fn memo_table_types(&self) -> Arc; + fn fmt_index(&self, index: crate::Id, fmt: &mut fmt::Formatter<'_>) -> fmt::Result; // Function ingredient methods diff --git a/src/input.rs b/src/input.rs index 4d4405d74..4802860e5 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,6 +1,7 @@ use std::any::{Any, TypeId}; use std::fmt; use std::ops::DerefMut; +use std::sync::Arc; pub mod input_field; pub mod setter; @@ -14,7 +15,7 @@ use crate::ingredient::{fmt_index, Ingredient}; use crate::input::singleton::{Singleton, SingletonChoice}; use crate::key::DatabaseKeyIndex; use crate::plumbing::{Jar, Stamp}; -use crate::table::memo::MemoTable; +use crate::table::memo::{MemoTable, MemoTableTypes}; use crate::table::sync::SyncTable; use crate::table::{Slot, Table}; use crate::zalsa::{IngredientIndex, Zalsa}; @@ -72,6 +73,7 @@ impl Jar for JarImpl { pub struct IngredientImpl { ingredient_index: IngredientIndex, singleton: C::Singleton, + memo_table_types: Arc, _phantom: std::marker::PhantomData, } @@ -80,6 +82,7 @@ impl IngredientImpl { Self { ingredient_index: index, singleton: Default::default(), + memo_table_types: Arc::new(MemoTableTypes::default()), _phantom: std::marker::PhantomData, } } @@ -100,7 +103,7 @@ impl IngredientImpl { let (zalsa, zalsa_local) = db.zalsas(); let id = self.singleton.with_scope(|| { - zalsa_local.allocate(zalsa.table(), self.ingredient_index, |_| Value:: { + zalsa_local.allocate(zalsa, self.ingredient_index, |_| Value:: { fields, stamps, memos: Default::default(), @@ -219,6 +222,10 @@ impl Ingredient for IngredientImpl { fn debug_name(&self) -> &'static str { C::DEBUG_NAME } + + fn memo_table_types(&self) -> Arc { + self.memo_table_types.clone() + } } impl std::fmt::Debug for IngredientImpl { diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 17aef7044..a987eb909 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,9 +1,11 @@ use std::fmt; use std::marker::PhantomData; +use std::sync::Arc; use crate::function::VerifyResult; use crate::ingredient::{fmt_index, Ingredient}; use crate::input::{Configuration, IngredientImpl, Value}; +use crate::table::memo::MemoTableTypes; use crate::zalsa::IngredientIndex; use crate::{Database, Id, Revision}; @@ -69,6 +71,10 @@ where fn debug_name(&self) -> &'static str { C::FIELD_DEBUG_NAMES[self.field_index] } + + fn memo_table_types(&self) -> Arc { + unreachable!("input fields do not allocate pages") + } } impl std::fmt::Debug for FieldIngredientImpl diff --git a/src/interned.rs b/src/interned.rs index 7be6a3420..0b0cb808d 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -16,7 +16,7 @@ use crate::hash::FxDashMap; use crate::ingredient::{fmt_index, Ingredient}; use crate::plumbing::{IngredientIndices, Jar}; use crate::revision::AtomicRevision; -use crate::table::memo::MemoTable; +use crate::table::memo::{MemoTable, MemoTableTypes}; use crate::table::sync::SyncTable; use crate::table::Slot; use crate::zalsa::{IngredientIndex, Zalsa}; @@ -62,6 +62,8 @@ pub struct IngredientImpl { /// /// Deadlock requirement: We access `value_map` while holding lock on `key_map`, but not vice versa. key_map: FxDashMap, Id>, + + memo_table_types: Arc, } /// Struct storing the interned fields. @@ -132,6 +134,7 @@ where Self { ingredient_index, key_map: Default::default(), + memo_table_types: Arc::new(MemoTableTypes::default()), } } @@ -279,8 +282,6 @@ where // We won any races so should intern the data Err(slot) => { - let table = zalsa.table(); - // Record the durability of the current query on the interned value. let durability = zalsa_local .active_query() @@ -288,7 +289,7 @@ where // If there is no active query this durability does not actually matter. .unwrap_or(Durability::MAX); - let id = zalsa_local.allocate(table, self.ingredient_index, |id| Value:: { + let id = zalsa_local.allocate(zalsa, self.ingredient_index, |id| Value:: { fields: unsafe { self.to_internal_data(assemble(id, key)) }, memos: Default::default(), syncs: Default::default(), @@ -298,7 +299,7 @@ where last_interned_at: AtomicRevision::from(current_revision), }); - let value = table.get::>(id); + let value = zalsa.table().get::>(id); let slot_value = (value.fields.clone(), SharedValue::new(id)); unsafe { lock.insert_in_slot(data_hash, slot, slot_value) }; @@ -307,7 +308,7 @@ where data_hash, self.key_map .hasher() - .hash_one(table.get::>(id).fields.clone()) + .hash_one(zalsa.table().get::>(id).fields.clone()) ); // Record a dependency on this value. @@ -409,6 +410,10 @@ where fn debug_name(&self) -> &'static str { C::DEBUG_NAME } + + fn memo_table_types(&self) -> Arc { + self.memo_table_types.clone() + } } impl std::fmt::Debug for IngredientImpl diff --git a/src/lib.rs b/src/lib.rs index 48676b83f..d69c59c17 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -87,6 +87,7 @@ pub mod plumbing { pub use crate::key::DatabaseKeyIndex; pub use crate::memo_ingredient_indices::{ IngredientIndices, MemoIngredientIndices, MemoIngredientMap, MemoIngredientSingletonIndex, + NewMemoIngredientIndices, }; pub use crate::revision::Revision; pub use crate::runtime::{stamp, Runtime, Stamp, StampedValue}; @@ -118,7 +119,10 @@ pub mod plumbing { } pub mod function { - pub use crate::function::{Configuration, IngredientImpl}; + pub use crate::function::Configuration; + pub use crate::function::IngredientImpl; + pub use crate::function::Memo; + pub use crate::table::memo::MemoEntryType; } pub mod tracked_struct { diff --git a/src/memo_ingredient_indices.rs b/src/memo_ingredient_indices.rs index a784b4ea4..6b2fc9c23 100644 --- a/src/memo_ingredient_indices.rs +++ b/src/memo_ingredient_indices.rs @@ -1,3 +1,6 @@ +use std::sync::Arc; + +use crate::table::memo::{MemoEntryType, MemoTableTypes}; use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::{Id, IngredientIndex}; @@ -42,11 +45,34 @@ impl IngredientIndices { } } -impl From<(&Zalsa, IngredientIndices, IngredientIndex)> for MemoIngredientIndices { - #[inline] - fn from( - (zalsa, struct_indices, ingredient): (&Zalsa, IngredientIndices, IngredientIndex), +pub trait NewMemoIngredientIndices { + /// # Safety + /// + /// The memo types must be correct. + unsafe fn create( + zalsa: &Zalsa, + struct_indices: IngredientIndices, + ingredient: IngredientIndex, + memo_type: MemoEntryType, + intern_ingredient_memo_types: Option>, + ) -> Self; +} + +impl NewMemoIngredientIndices for MemoIngredientIndices { + /// # Safety + /// + /// The memo types must be correct. + unsafe fn create( + zalsa: &Zalsa, + struct_indices: IngredientIndices, + ingredient: IngredientIndex, + memo_type: MemoEntryType, + _intern_ingredient_memo_types: Option>, ) -> Self { + debug_assert!( + _intern_ingredient_memo_types.is_none(), + "intern ingredient can only have a singleton memo ingredient" + ); let Some(&last) = struct_indices.indices.last() else { unreachable!("Attempting to construct struct memo mapping for non tracked function?") }; @@ -56,8 +82,14 @@ impl From<(&Zalsa, IngredientIndices, IngredientIndex)> for MemoIngredientIndice MemoIngredientIndex::from_usize((u32::MAX - 1) as usize), ); for &struct_ingredient in &struct_indices.indices { - indices[struct_ingredient.as_usize()] = - zalsa.next_memo_ingredient_index(struct_ingredient, ingredient); + let memo_types = zalsa + .lookup_ingredient(struct_ingredient) + .memo_table_types(); + + let mi = zalsa.next_memo_ingredient_index(struct_ingredient, ingredient); + memo_types.set(mi, &memo_type); + + indices[struct_ingredient.as_usize()] = mi; } MemoIngredientIndices { indices: indices.into_boxed_slice(), @@ -112,14 +144,28 @@ impl MemoIngredientMap for MemoIngredientSingletonIndex { } } -impl From<(&Zalsa, IngredientIndices, IngredientIndex)> for MemoIngredientSingletonIndex { +impl NewMemoIngredientIndices for MemoIngredientSingletonIndex { #[inline] - fn from((zalsa, indices, ingredient): (&Zalsa, IngredientIndices, IngredientIndex)) -> Self { + unsafe fn create( + zalsa: &Zalsa, + indices: IngredientIndices, + ingredient: IngredientIndex, + memo_type: MemoEntryType, + intern_ingredient_memo_types: Option>, + ) -> Self { let &[struct_ingredient] = &*indices.indices else { unreachable!("Attempting to construct struct memo mapping from enum?") }; - Self(zalsa.next_memo_ingredient_index(struct_ingredient, ingredient)) + let memo_types = intern_ingredient_memo_types.unwrap_or_else(|| { + zalsa + .lookup_ingredient(struct_ingredient) + .memo_table_types() + }); + + let mi = zalsa.next_memo_ingredient_index(struct_ingredient, ingredient); + memo_types.set(mi, &memo_type); + Self(mi) } } diff --git a/src/runtime.rs b/src/runtime.rs index c1c31184a..e9450840b 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -133,6 +133,7 @@ impl Runtime { } /// Returns the [`Table`] used to store the value of salsa structs + #[inline] pub(crate) fn table(&self) -> &Table { &self.table } diff --git a/src/table.rs b/src/table.rs index 768746ca5..113b3146e 100644 --- a/src/table.rs +++ b/src/table.rs @@ -6,12 +6,14 @@ use std::mem::{self, MaybeUninit}; use std::ptr::{self, NonNull}; use std::slice; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; use memo::MemoTable; use parking_lot::Mutex; use rustc_hash::FxHashMap; use sync::SyncTable; +use crate::table::memo::{MemoTableTypes, MemoTableWithTypes, MemoTableWithTypesMut}; use crate::{Id, IngredientIndex, Revision}; pub(crate) mod memo; @@ -71,20 +73,23 @@ struct SlotVTable { memos_mut: SlotMemosMutFnRaw, syncs: SlotSyncsFnRaw, /// A drop impl to call when the own page drops - /// SAFETY: The caller is required to supply a correct data pointer to a `Box>` and initialized length - drop_impl: unsafe fn(data: *mut (), initialized: usize), + /// SAFETY: The caller is required to supply a correct data pointer to a `Box>` and initialized length, + /// and correct memo types. + drop_impl: unsafe fn(data: *mut (), initialized: usize, memo_types: &MemoTableTypes), } impl SlotVTable { const fn of() -> &'static Self { const { &Self { - drop_impl: |data, initialized| + drop_impl: |data, initialized, memo_types| // SAFETY: The caller is required to supply a correct data pointer and initialized length unsafe { let data = Box::from_raw(data.cast::>()); for i in 0..initialized { - ptr::drop_in_place(data[i].get().cast::()); + let item = data[i].get().cast::(); + memo_types.attach_memos_mut((*item).memos_mut()).drop(); + ptr::drop_in_place(item); } }, layout: Layout::new::(), @@ -133,6 +138,8 @@ struct Page { /// The type name of what is stored as entries in data. // FIXME: Move this into SlotVTable once const stable slot_type_name: &'static str, + + memo_types: Arc, } // SAFETY: `Page` is `Send` as we make sure to only ever store `Slot` types in it which @@ -146,6 +153,7 @@ unsafe impl Sync for Page {} pub struct PageIndex(usize); impl PageIndex { + #[inline] fn new(idx: usize) -> Self { debug_assert!(idx < MAX_PAGES); Self(idx) @@ -215,8 +223,13 @@ impl Table { } /// Allocate a new page for the given ingredient and with slots of type `T` - pub(crate) fn push_page(&self, ingredient: IngredientIndex) -> PageIndex { - PageIndex::new(self.pages.push(Page::new::(ingredient))) + #[inline] + pub(crate) fn push_page( + &self, + ingredient: IngredientIndex, + memo_types: Arc, + ) -> PageIndex { + PageIndex::new(self.pages.push(Page::new::(ingredient, memo_types))) } /// Get the memo table associated with `id` @@ -225,15 +238,21 @@ impl Table { /// /// The parameter `current_revision` MUST be the current revision /// of the owner of database owning this table. - pub(crate) unsafe fn memos(&self, id: Id, current_revision: Revision) -> &MemoTable { + pub(crate) unsafe fn memos( + &self, + id: Id, + current_revision: Revision, + ) -> MemoTableWithTypes<'_> { let (page, slot) = split_id(id); let page = &self.pages[page.0]; // SAFETY: We supply a proper slot pointer and the caller is required to pass the `current_revision`. - unsafe { &*(page.slot_vtable.memos)(page.get(slot), current_revision) } + let memos = unsafe { &*(page.slot_vtable.memos)(page.get(slot), current_revision) }; + // SAFETY: The `Page` keeps the correct memo types. + unsafe { page.memo_types.attach_memos(memos) } } /// Get the memo table associated with `id` - pub(crate) fn memos_mut(&mut self, id: Id) -> &mut MemoTable { + pub(crate) fn memos_mut(&mut self, id: Id) -> MemoTableWithTypesMut<'_> { let (page, slot) = split_id(id); let page_index = page.0; let page = self @@ -241,7 +260,9 @@ impl Table { .get_mut(page_index) .unwrap_or_else(|| panic!("index `{page_index}` is uninitialized")); // SAFETY: We supply a proper slot pointer and the caller is required to pass the `current_revision`. - unsafe { &mut *(page.slot_vtable.memos_mut)(page.get(slot)) } + let memos = unsafe { &mut *(page.slot_vtable.memos_mut)(page.get(slot)) }; + // SAFETY: The `Page` keeps the correct memo types. + unsafe { page.memo_types.attach_memos_mut(memos) } } /// Get the sync table associated with `id` @@ -264,7 +285,11 @@ impl Table { .flat_map(|view| view.data()) } - pub(crate) fn fetch_or_push_page(&self, ingredient: IngredientIndex) -> PageIndex { + pub(crate) fn fetch_or_push_page( + &self, + ingredient: IngredientIndex, + memo_types: impl FnOnce() -> Arc, + ) -> PageIndex { if let Some(page) = self .non_full_pages .lock() @@ -273,7 +298,7 @@ impl Table { { return page; } - self.push_page::(ingredient) + self.push_page::(ingredient, memo_types()) } pub(crate) fn record_unfilled_page(&self, ingredient: IngredientIndex, page: PageIndex) { @@ -325,7 +350,8 @@ impl<'p, T: Slot> PageView<'p, T> { } impl Page { - fn new(ingredient: IngredientIndex) -> Self { + #[inline] + fn new(ingredient: IngredientIndex, memo_types: Arc) -> Self { let data: Box> = Box::new([const { UnsafeCell::new(MaybeUninit::uninit()) }; PAGE_LEN]); Self { @@ -336,6 +362,7 @@ impl Page { allocated: Default::default(), allocation_lock: Default::default(), data: NonNull::from(Box::leak(data)).cast::<()>(), + memo_types, } } @@ -382,7 +409,7 @@ impl Drop for Page { fn drop(&mut self) { let &mut len = self.allocated.get_mut(); // SAFETY: We supply the data pointer and the initialized length - unsafe { (self.slot_vtable.drop_impl)(self.data.as_ptr(), len) }; + unsafe { (self.slot_vtable.drop_impl)(self.data.as_ptr(), len, &self.memo_types) }; } } diff --git a/src/table/memo.rs b/src/table/memo.rs index 07e3085d7..5dbe3c9e6 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -1,12 +1,18 @@ -use std::any::{Any, TypeId}; -use std::ptr::NonNull; -use std::sync::atomic::{AtomicPtr, Ordering}; +use std::{ + any::{Any, TypeId}, + fmt::Debug, + mem, + ptr::{self, NonNull}, + sync::{ + atomic::{AtomicPtr, Ordering}, + OnceLock, + }, +}; use parking_lot::RwLock; use thin_vec::ThinVec; -use crate::zalsa::MemoIngredientIndex; -use crate::zalsa_local::QueryOrigin; +use crate::{zalsa::MemoIngredientIndex, zalsa_local::QueryOrigin}; /// The "memo table" stores the memoized results of tracked function calls. /// Every tracked function must take a salsa struct as its first argument @@ -16,19 +22,11 @@ pub(crate) struct MemoTable { memos: RwLock>, } -pub(crate) trait Memo: Any + Send + Sync { +pub trait Memo: Any + Send + Sync { /// Returns the `origin` of this memo fn origin(&self) -> &QueryOrigin; } -/// Wraps the data stored for a memoized entry. -/// This struct has a customized Drop that will -/// ensure that its `data` field is properly freed. -#[derive(Default)] -struct MemoEntry { - data: Option, -} - /// Data for a memoized entry. /// This is a type-erased `Box`, where `M` is the type of memo associated /// with that particular ingredient index. @@ -46,21 +44,26 @@ struct MemoEntry { /// Therefore, we hide the type by transmuting to `DummyMemo`; but we must then be very careful /// when freeing `MemoEntryData` values to transmute things back. See the `Drop` impl for /// [`MemoEntry`][] for details. -struct MemoEntryData { +#[derive(Default)] +struct MemoEntry { + /// An [`AtomicPtr`][] to a `Box` for the erased memo type `M` + atomic_memo: AtomicPtr, +} + +pub struct MemoEntryType { + data: OnceLock, +} + +#[derive(Clone, Copy)] +struct MemoEntryTypeData { /// The `type_id` of the erased memo type `M` type_id: TypeId, /// A type-coercion function for the erased memo type `M` to_dyn_fn: fn(NonNull) -> NonNull, - - /// An [`AtomicPtr`][] to a `Box` for the erased memo type `M` - atomic_memo: AtomicPtr, } -/// Dummy placeholder type that we use when erasing the memo type `M` in [`MemoEntryData`][]. -struct DummyMemo {} - -impl MemoTable { +impl MemoEntryType { fn to_dummy(memo: NonNull) -> NonNull { memo.cast() } @@ -69,51 +72,138 @@ impl MemoTable { memo.cast() } - fn to_dyn_fn() -> fn(NonNull) -> NonNull { + const fn to_dyn_fn() -> fn(NonNull) -> NonNull { let f: fn(NonNull) -> NonNull = |x| x; #[allow(clippy::undocumented_unsafe_blocks)] // TODO(#697) document safety unsafe { - std::mem::transmute::< + mem::transmute::< fn(NonNull) -> NonNull, fn(NonNull) -> NonNull, >(f) } } + #[inline] + pub fn of() -> Self { + Self { + data: OnceLock::from(MemoEntryTypeData { + type_id: TypeId::of::(), + to_dyn_fn: Self::to_dyn_fn::(), + }), + } + } + + #[inline] + fn load(&self) -> Option<&MemoEntryTypeData> { + self.data.get() + } +} + +/// Dummy placeholder type that we use when erasing the memo type `M` in [`MemoEntryData`][]. +#[derive(Debug)] +struct DummyMemo {} + +impl Memo for DummyMemo { + fn origin(&self) -> &QueryOrigin { + unreachable!("should not get here") + } +} + +#[derive(Default)] +pub struct MemoTableTypes { + types: boxcar::Vec, +} + +impl MemoTableTypes { + pub(crate) fn set( + &self, + memo_ingredient_index: MemoIngredientIndex, + memo_type: &MemoEntryType, + ) { + let memo_ingredient_index = memo_ingredient_index.as_usize(); + while memo_ingredient_index >= self.types.count() { + self.types.push(MemoEntryType { + data: OnceLock::new(), + }); + } + let memo_entry_type = self.types.get(memo_ingredient_index).unwrap(); + memo_entry_type + .data + .set( + *memo_type + .data + .get() + .expect("cannot provide an empty `MemoEntryType` for `MemoEntryType::set()`"), + ) + .ok() + .expect("memo type should only be set once"); + } + /// # Safety /// - /// The caller needs to make sure to not free the returned value until no more references into - /// the database exist as there may be outstanding borrows into the pointer contents. + /// The types table must be the correct one of `memos`. + #[inline] + pub(crate) unsafe fn attach_memos<'a>( + &'a self, + memos: &'a MemoTable, + ) -> MemoTableWithTypes<'a> { + MemoTableWithTypes { types: self, memos } + } + + /// # Safety + /// + /// The types table must be the correct one of `memos`. + #[inline] + pub(crate) unsafe fn attach_memos_mut<'a>( + &'a self, + memos: &'a mut MemoTable, + ) -> MemoTableWithTypesMut<'a> { + MemoTableWithTypesMut { types: self, memos } + } +} + +pub(crate) struct MemoTableWithTypes<'a> { + types: &'a MemoTableTypes, + memos: &'a MemoTable, +} + +impl<'a> MemoTableWithTypes<'a> { + /// # Safety + /// + /// The caller needs to make sure to not drop the returned value until no more references into + /// the database exist as there may be outstanding borrows into the `Arc` contents. pub(crate) unsafe fn insert( - &self, + self, memo_ingredient_index: MemoIngredientIndex, memo: NonNull, ) -> Option> { + // The type must already exist, we insert it when creating the memo ingredient. + assert_eq!( + self.types + .types + .get(memo_ingredient_index.as_usize()) + .and_then(MemoEntryType::load)? + .type_id, + TypeId::of::(), + "inconsistent type-id for `{memo_ingredient_index:?}`" + ); + // If the memo slot is already occupied, it must already have the // right type info etc, and we only need the read-lock. - if let Some(MemoEntry { - data: - Some(MemoEntryData { - type_id, - to_dyn_fn: _, - atomic_memo, - }), - }) = self.memos.read().get(memo_ingredient_index.as_usize()) + if let Some(MemoEntry { atomic_memo }) = self + .memos + .memos + .read() + .get(memo_ingredient_index.as_usize()) { - assert_eq!( - *type_id, - TypeId::of::(), - "inconsistent type-id for `{memo_ingredient_index:?}`" - ); + let old_memo = + atomic_memo.swap(MemoEntryType::to_dummy(memo).as_ptr(), Ordering::AcqRel); - let old_memo = atomic_memo.swap(Self::to_dummy(memo).as_ptr(), Ordering::AcqRel); - - // SAFETY: The `atomic_memo` field is never null. - let old_memo = unsafe { NonNull::new_unchecked(old_memo) }; + let old_memo = NonNull::new(old_memo); // SAFETY: `type_id` check asserted above - return Some(unsafe { Self::from_dummy(old_memo) }); + return old_memo.map(|old_memo| unsafe { MemoEntryType::from_dummy(old_memo) }); } // Otherwise we need the write lock. @@ -123,142 +213,139 @@ impl MemoTable { /// # Safety /// - /// The caller needs to make sure to not free the returned value until no more references into - /// the database exist as there may be outstanding borrows into the pointer contents. + /// The caller needs to make sure to not drop the returned value until no more references into + /// the database exist as there may be outstanding borrows into the `Arc` contents. unsafe fn insert_cold( - &self, + self, memo_ingredient_index: MemoIngredientIndex, memo: NonNull, ) -> Option> { - let mut memos = self.memos.write(); let memo_ingredient_index = memo_ingredient_index.as_usize(); + let mut memos = self.memos.memos.write(); + let additional_len = memo_ingredient_index - memos.len() + 1; + memos.reserve(additional_len); while memos.len() < memo_ingredient_index + 1 { - memos.push(MemoEntry { data: None }); + memos.push(MemoEntry::default()); } - let old_entry = memos[memo_ingredient_index].data.replace(MemoEntryData { - type_id: TypeId::of::(), - to_dyn_fn: Self::to_dyn_fn::(), - atomic_memo: AtomicPtr::new(Self::to_dummy(memo).as_ptr()), - }); - old_entry.map( - |MemoEntryData { - type_id: _, - to_dyn_fn: _, - atomic_memo, - }| - // SAFETY: The `atomic_memo` field is never null. - unsafe { Self::from_dummy(NonNull::new_unchecked(atomic_memo.into_inner())) }, - ) + let old_entry = mem::replace( + memos[memo_ingredient_index].atomic_memo.get_mut(), + MemoEntryType::to_dummy(memo).as_ptr(), + ); + let old_entry = NonNull::new(old_entry); + // SAFETY: The `TypeId` is asserted in `insert()`. + old_entry.map(|memo| unsafe { MemoEntryType::from_dummy(memo) }) } - pub(crate) fn get(&self, memo_ingredient_index: MemoIngredientIndex) -> Option<&M> { - let memos = self.memos.read(); - - let Some(MemoEntry { - data: - Some(MemoEntryData { - type_id, - to_dyn_fn: _, - atomic_memo, - }), - }) = memos.get(memo_ingredient_index.as_usize()) - else { - return None; - }; - + #[inline] + pub(crate) fn get(self, memo_ingredient_index: MemoIngredientIndex) -> Option<&'a M> { + let read = self.memos.memos.read(); + let memo = read.get(memo_ingredient_index.as_usize())?; + let type_ = self + .types + .types + .get(memo_ingredient_index.as_usize()) + .and_then(MemoEntryType::load)?; assert_eq!( - *type_id, + type_.type_id, TypeId::of::(), "inconsistent type-id for `{memo_ingredient_index:?}`" ); - - // SAFETY: The `atomic_memo` field is never null. - let memo = unsafe { NonNull::new_unchecked(atomic_memo.load(Ordering::Acquire)) }; - + let memo = NonNull::new(memo.atomic_memo.load(Ordering::Acquire)); // SAFETY: `type_id` check asserted above - unsafe { Some(Self::from_dummy(memo).as_ref()) } + memo.map(|memo| unsafe { MemoEntryType::from_dummy(memo).as_ref() }) } +} +pub(crate) struct MemoTableWithTypesMut<'a> { + types: &'a MemoTableTypes, + memos: &'a mut MemoTable, +} + +impl MemoTableWithTypesMut<'_> { /// Calls `f` on the memo at `memo_ingredient_index`. /// /// If the memo is not present, `f` is not called. pub(crate) fn map_memo( - &mut self, + self, memo_ingredient_index: MemoIngredientIndex, f: impl FnOnce(&mut M), ) { - let memos = self.memos.get_mut(); - let Some(MemoEntry { - data: - Some(MemoEntryData { - type_id, - to_dyn_fn: _, - atomic_memo, - }), - }) = memos.get_mut(memo_ingredient_index.as_usize()) + let Some(type_) = self + .types + .types + .get(memo_ingredient_index.as_usize()) + .and_then(MemoEntryType::load) else { return; }; - assert_eq!( - *type_id, + type_.type_id, TypeId::of::(), "inconsistent type-id for `{memo_ingredient_index:?}`" ); - // SAFETY: The `atomic_memo` field is never null. - let memo = unsafe { NonNull::new_unchecked(*atomic_memo.get_mut()) }; + // If the memo slot is already occupied, it must already have the + // right type info etc, and we only need the read-lock. + let memos = self.memos.memos.get_mut(); + let Some(MemoEntry { atomic_memo }) = memos.get_mut(memo_ingredient_index.as_usize()) + else { + return; + }; + let Some(memo) = NonNull::new(*atomic_memo.get_mut()) else { + return; + }; // SAFETY: `type_id` check asserted above - f(unsafe { Self::from_dummy(memo).as_mut() }); + f(unsafe { MemoEntryType::from_dummy(memo).as_mut() }); + } + + /// To drop an entry, we need its type, so we don't implement `Drop`, and instead have this method. + #[inline] + pub fn drop(self) { + let types = self.types.types.iter(); + for ((_, type_), memo) in std::iter::zip(types, self.memos.memos.get_mut()) { + // SAFETY: The types match because this is an invariant of `MemoTableWithTypesMut`. + unsafe { memo.drop(type_) }; + } } /// # Safety /// /// The caller needs to make sure to not call this function until no more references into /// the database exist as there may be outstanding borrows into the pointer contents. - pub(crate) unsafe fn into_memos( - self, - ) -> impl Iterator)> { - self.memos - .into_inner() - .into_iter() + pub(crate) unsafe fn with_memos(self, mut f: impl FnMut(MemoIngredientIndex, Box)) { + let memos = self.memos.memos.get_mut(); + memos + .iter_mut() + .zip(self.types.types.iter()) .zip(0..) - .filter_map(|(mut memo, index)| memo.data.take().map(|d| (d, index))) - .map( - |( - MemoEntryData { - type_id: _, - to_dyn_fn, - atomic_memo, - }, - index, - )| { - // SAFETY: The `atomic_memo` field is never null. - let memo = - unsafe { to_dyn_fn(NonNull::new_unchecked(atomic_memo.into_inner())) }; - // SAFETY: The caller guarantees that there are no outstanding borrows into the `Box` contents. - let memo = unsafe { Box::from_raw(memo.as_ptr()) }; - - (MemoIngredientIndex::from_usize(index), memo) - }, - ) + .filter_map(|((memo, (_, type_)), index)| { + let memo = mem::replace(memo.atomic_memo.get_mut(), ptr::null_mut()); + let memo = NonNull::new(memo)?; + Some((memo, type_.load()?, index)) + }) + .map(|(memo, type_, index)| { + // SAFETY: We took ownership of the memo, and converted it to the correct type. + // The caller guarantees that there are no outstanding borrows into the `Box` contents. + let memo = unsafe { Box::from_raw((type_.to_dyn_fn)(memo).as_ptr()) }; + (MemoIngredientIndex::from_usize(index), memo) + }) + .for_each(|(index, memo)| f(index, memo)); } } -impl Drop for MemoEntry { - fn drop(&mut self) { - if let Some(MemoEntryData { - type_id: _, - to_dyn_fn, - atomic_memo, - }) = self.data.take() +impl MemoEntry { + /// # Safety + /// + /// The type must match. + #[inline] + unsafe fn drop(&mut self, type_: &MemoEntryType) { + if let Some(memo) = NonNull::new(mem::replace(self.atomic_memo.get_mut(), ptr::null_mut())) { - // SAFETY: The `atomic_memo` field is never null. - let memo = unsafe { to_dyn_fn(NonNull::new_unchecked(atomic_memo.into_inner())) }; - // SAFETY: We have `&mut self`, so there are no outstanding borrows into the `Box` contents. - let memo = unsafe { Box::from_raw(memo.as_ptr()) }; - std::mem::drop(memo); + if let Some(type_) = type_.load() { + // SAFETY: Our preconditions. + mem::drop(unsafe { Box::from_raw((type_.to_dyn_fn)(memo).as_ptr()) }); + } } } } @@ -271,6 +358,6 @@ impl Drop for DummyMemo { impl std::fmt::Debug for MemoTable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MemoTable").finish() + f.debug_struct("MemoTable").finish_non_exhaustive() } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index dc835daad..ec389ecbd 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -5,6 +5,7 @@ use std::fmt; use std::hash::Hash; use std::marker::PhantomData; use std::ops::DerefMut; +use std::sync::Arc; use crossbeam_queue::SegQueue; use tracked_field::FieldIngredientImpl; @@ -16,7 +17,7 @@ use crate::plumbing::ZalsaLocal; use crate::revision::OptionalAtomicRevision; use crate::runtime::StampedValue; use crate::salsa_struct::SalsaStructInDb; -use crate::table::memo::MemoTable; +use crate::table::memo::{MemoTable, MemoTableTypes}; use crate::table::sync::SyncTable; use crate::table::{Slot, Table}; use crate::zalsa::{IngredientIndex, Zalsa}; @@ -168,6 +169,8 @@ where /// Store freed ids free_list: SegQueue, + + memo_table_types: Arc, } /// Defines the identity of a tracked struct. @@ -385,6 +388,7 @@ where ingredient_index: index, phantom: PhantomData, free_list: Default::default(), + memo_table_types: Arc::new(MemoTableTypes::default()), } } @@ -466,7 +470,7 @@ where id } else { - zalsa_local.allocate::>(zalsa.table(), self.ingredient_index, value) + zalsa_local.allocate::>(zalsa, self.ingredient_index, value) } } @@ -629,22 +633,33 @@ where // Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None` // and the code that references the memo-table has a read-lock. - let memo_table = unsafe { (*data).take_memo_table() }; - + struct MemoTableWithTypes<'a>(MemoTable, &'a MemoTableTypes); + impl Drop for MemoTableWithTypes<'_> { + fn drop(&mut self) { + // SAFETY: We use the correct types table. + unsafe { self.1.attach_memos_mut(&mut self.0) }.drop(); + } + } + let mut memo_table = + MemoTableWithTypes(unsafe { (*data).take_memo_table() }, &self.memo_table_types); // SAFETY: We have verified that no more references to these memos exist and so we are good // to drop them. - for (memo_ingredient_index, memo) in unsafe { memo_table.into_memos() } { - let ingredient_index = - zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index); + unsafe { + memo_table.1.attach_memos_mut(&mut memo_table.0).with_memos( + |memo_ingredient_index, memo| { + let ingredient_index = zalsa + .ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index); - let executor = DatabaseKeyIndex::new(ingredient_index, id); + let executor = DatabaseKeyIndex::new(ingredient_index, id); - db.salsa_event(&|| Event::new(EventKind::DidDiscard { key: executor })); + db.salsa_event(&|| Event::new(EventKind::DidDiscard { key: executor })); - for stale_output in memo.origin().outputs() { - stale_output.remove_stale_output(zalsa, db, executor, provisional); - } - } + for stale_output in memo.origin().outputs() { + stale_output.remove_stale_output(zalsa, db, executor, provisional); + } + }, + ) + }; // now that all cleanup has occurred, make available for re-use self.free_list.push(id); @@ -790,6 +805,10 @@ where fn debug_name(&self) -> &'static str { C::DEBUG_NAME } + + fn memo_table_types(&self) -> Arc { + self.memo_table_types.clone() + } } impl std::fmt::Debug for IngredientImpl diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 03e435a88..8c09cbefa 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,7 +1,8 @@ -use std::marker::PhantomData; +use std::{marker::PhantomData, sync::Arc}; use crate::function::VerifyResult; use crate::ingredient::Ingredient; +use crate::table::memo::MemoTableTypes; use crate::tracked_struct::{Configuration, Value}; use crate::zalsa::IngredientIndex; use crate::{Database, Id}; @@ -24,7 +25,6 @@ where /// The absolute index of this field on the tracked struct. field_index: usize, - phantom: PhantomData Value>, } @@ -74,6 +74,10 @@ where fn debug_name(&self) -> &'static str { C::FIELD_DEBUG_NAMES[self.field_index] } + + fn memo_table_types(&self) -> Arc { + unreachable!("tracked field does not allocate pages") + } } impl std::fmt::Debug for FieldIngredientImpl diff --git a/src/zalsa.rs b/src/zalsa.rs index bb5716343..c21a713d3 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -13,7 +13,7 @@ use rustc_hash::FxHashMap; use crate::ingredient::{Ingredient, Jar}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::Runtime; -use crate::table::memo::MemoTable; +use crate::table::memo::MemoTableWithTypes; use crate::table::sync::SyncTable; use crate::table::Table; use crate::views::Views; @@ -106,6 +106,7 @@ impl MemoIngredientIndex { MemoIngredientIndex(u as u32) } + #[inline] pub(crate) fn as_usize(self) -> usize { self.0 as usize } @@ -184,14 +185,16 @@ impl Zalsa { } /// Returns the [`Table`] used to store the value of salsa structs + #[inline] pub(crate) fn table(&self) -> &Table { self.runtime.table() } /// Returns the [`MemoTable`][] for the salsa struct with the given id - pub(crate) fn memo_table_for(&self, id: Id) -> &MemoTable { + pub(crate) fn memo_table_for(&self, id: Id) -> MemoTableWithTypes<'_> { + let table = self.table(); // SAFETY: We are supplying the correct current revision - unsafe { self.table().memos(id, self.current_revision()) } + unsafe { table.memos(id, self.current_revision()) } } /// Returns the [`SyncTable`][] for the salsa struct with the given id @@ -248,6 +251,7 @@ impl Zalsa { }; let mi = MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap()); memo_ingredients.push(ingredient_index); + mi } } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 70d48e506..b3096a03d 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -13,7 +13,7 @@ use crate::key::DatabaseKeyIndex; use crate::runtime::Stamp; use crate::table::{PageIndex, Slot, Table}; use crate::tracked_struct::{Disambiguator, Identity, IdentityHash, IdentityMap}; -use crate::zalsa::IngredientIndex; +use crate::zalsa::{IngredientIndex, Zalsa}; use crate::{Accumulator, Cancelled, Id, Revision}; /// State that is specific to a single execution thread. @@ -54,20 +54,30 @@ impl ZalsaLocal { /// thread and attempts to reuse it. pub(crate) fn allocate( &self, - table: &Table, + zalsa: &Zalsa, ingredient: IngredientIndex, mut value: impl FnOnce(Id) -> T, ) -> Id { + let memo_types = || { + zalsa + .lookup_ingredient(ingredient) + .memo_table_types() + .clone() + }; // Find the most recent page, pushing a page if needed let mut page = *self .most_recent_pages .borrow_mut() .entry(ingredient) - .or_insert_with(|| table.fetch_or_push_page::(ingredient)); + .or_insert_with(|| { + zalsa + .table() + .fetch_or_push_page::(ingredient, memo_types) + }); loop { // Try to allocate an entry on that page - let page_ref = table.page::(page); + let page_ref = zalsa.table().page::(page); match page_ref.allocate(page, value) { // If successful, return Ok(id) => return id, @@ -77,7 +87,7 @@ impl ZalsaLocal { // it is unlikely that there is a non-full one available. Err(v) => { value = v; - page = table.push_page::(ingredient); + page = zalsa.table().push_page::(ingredient, memo_types()); self.most_recent_pages.borrow_mut().insert(ingredient, page); } }