diff --git a/Cargo.toml b/Cargo.toml index 0dc1bfefc..fca9ae8c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ rust-version = "1.76" [dependencies] arc-swap = "1" +compact_str = { version = "0.8", optional = true } crossbeam = "0.8" dashmap = { version = "6", features = ["raw-api"] } hashlink = "0.9" diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 439e4aa0d..68bc68e0a 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -55,6 +55,8 @@ macro_rules! setup_tracked_fn { // True if we `return_ref` flag was given to the function return_ref: $return_ref:tt, + maybe_update_fn: {$($maybe_update_fn:tt)*}, + // Annoyingly macro-rules hygiene does not extend to items defined in the macro. // We have the procedural macro generate names for those items that are // not used elsewhere in the user's code. @@ -145,6 +147,14 @@ macro_rules! setup_tracked_fn { } } + /// This method isn't used anywhere. It only exitst to enforce the `Self::Output: Update` constraint + /// for types that aren't `'static`. + /// + /// # Safety + /// The same safety rules as for `Update` apply. + $($maybe_update_fn)* + + impl $zalsa::function::Configuration for $Configuration { const DEBUG_NAME: &'static str = stringify!($fn_name); diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 98db9a958..c74a265ff 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -117,6 +117,18 @@ impl Macro { let return_ref: bool = self.args.return_ref.is_some(); + let maybe_update_fn = quote_spanned! {output_ty.span()=> { + #[allow(clippy::all, unsafe_code)] + unsafe fn _maybe_update_fn<'db>(old_pointer: *mut #output_ty, new_value: #output_ty) -> bool { + unsafe { + use #zalsa::UpdateFallback; + #zalsa::UpdateDispatch::<#output_ty>::maybe_update( + old_pointer, new_value + ) + } + } + }}; + Ok(crate::debug::dump_tokens( fn_name, quote![salsa::plumbing::setup_tracked_fn! { @@ -137,6 +149,7 @@ impl Macro { needs_interner: #needs_interner, lru: #lru, return_ref: #return_ref, + maybe_update_fn: { #maybe_update_fn }, unused_names: [ #zalsa, #Configuration, diff --git a/components/salsa-macros/src/update.rs b/components/salsa-macros/src/update.rs index 837f146c8..c779ad6ea 100644 --- a/components/salsa-macros/src/update.rs +++ b/components/salsa-macros/src/update.rs @@ -1,4 +1,5 @@ use proc_macro2::{Literal, TokenStream}; +use syn::spanned::Spanned; use synstructure::BindStyle; use crate::hygiene::Hygiene; @@ -34,7 +35,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result .bindings() .iter() .fold(quote!(), |tokens, binding| quote!(#tokens #binding,)); - let make_new_value = quote! { + let make_new_value = quote_spanned! {variant.ast().ident.span()=> let #new_value = if let #variant_pat = #new_value { (#make_tuple) } else { @@ -46,20 +47,28 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result // For each field, invoke `maybe_update` recursively to update its value. // Or the results together (using `|`, not `||`, to avoid shortcircuiting) // to get the final return value. - let update_fields = variant.bindings().iter().zip(0..).fold( + let update_fields = variant.bindings().iter().enumerate().fold( quote!(false), - |tokens, (binding, index)| { + |tokens, (index, binding)| { let field_ty = &binding.ast().ty; let field_index = Literal::usize_unsuffixed(index); + let field_span = binding + .ast() + .ident + .as_ref() + .map(Spanned::span) + .unwrap_or(binding.ast().span()); + + let update_field = quote_spanned! {field_span=> + salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update( + #binding, + #new_value.#field_index, + ) + }; + quote! { - #tokens | - unsafe { - salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update( - #binding, - #new_value.#field_index, - ) - } + #tokens | unsafe { #update_field } } }, ); @@ -77,6 +86,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let tokens = quote! { #[allow(clippy::all)] + #[automatically_derived] unsafe impl #impl_generics salsa::Update for #ident #ty_generics #where_clause { unsafe fn maybe_update(#old_pointer: *mut Self, #new_value: Self) -> bool { use ::salsa::plumbing::UpdateFallback as _; diff --git a/examples/calc/ir.rs b/examples/calc/ir.rs index c439b4a61..891cdbd29 100644 --- a/examples/calc/ir.rs +++ b/examples/calc/ir.rs @@ -65,7 +65,7 @@ pub enum ExpressionData<'db> { Call(FunctionId<'db>, Vec>), } -#[derive(Eq, PartialEq, Copy, Clone, Hash, Debug, salsa::Update)] +#[derive(Eq, PartialEq, Copy, Clone, Hash, Debug)] pub enum Op { Add, Subtract, diff --git a/src/function.rs b/src/function.rs index 2bde6e58e..639b1819b 100644 --- a/src/function.rs +++ b/src/function.rs @@ -9,7 +9,7 @@ use crate::{ salsa_struct::SalsaStructInDb, zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}, zalsa_local::QueryOrigin, - Cycle, Database, Id, Revision, Update, + Cycle, Database, Id, Revision, }; use self::delete::DeletedEntries; @@ -43,7 +43,7 @@ pub trait Configuration: Any { type Input<'db>: Send + Sync; /// The value computed by the function. - type Output<'db>: fmt::Debug + Send + Sync + Update; + type Output<'db>: fmt::Debug + Send + Sync; /// Determines whether this function can recover from being a participant in a cycle /// (and, if so, how). diff --git a/src/update.rs b/src/update.rs index c3e02f0be..a731e293b 100644 --- a/src/update.rs +++ b/src/update.rs @@ -1,6 +1,7 @@ use std::{ collections::{BTreeMap, BTreeSet, HashMap, HashSet}, hash::{BuildHasher, Hash}, + marker::PhantomData, path::PathBuf, sync::Arc, }; @@ -188,6 +189,29 @@ where } } +unsafe impl Update for smallvec::SmallVec +where + A: smallvec::Array, + A::Item: Update, +{ + unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool { + let old_vec: &mut smallvec::SmallVec = unsafe { &mut *old_pointer }; + + if old_vec.len() != new_vec.len() { + old_vec.clear(); + old_vec.extend(new_vec); + return true; + } + + let mut changed = false; + for (old_element, new_element) in old_vec.iter_mut().zip(new_vec) { + changed |= A::Item::maybe_update(old_element, new_element); + } + + changed + } +} + macro_rules! maybe_update_set { ($old_pointer: expr, $new_set: expr) => {{ let old_pointer = $old_pointer; @@ -291,6 +315,26 @@ where } } +unsafe impl Update for Box<[T]> +where + T: Update, +{ + unsafe fn maybe_update(old_pointer: *mut Self, new_box: Self) -> bool { + let old_box: &mut Box<[T]> = unsafe { &mut *old_pointer }; + + if old_box.len() == new_box.len() { + let mut changed = false; + for (old_element, new_element) in old_box.iter_mut().zip(new_box) { + changed |= T::maybe_update(old_element, new_element); + } + changed + } else { + *old_box = new_box; + true + } + } +} + unsafe impl Update for Arc where T: Update, @@ -398,6 +442,9 @@ fallback_impl! { PathBuf, } +#[cfg(feature = "compact_str")] +fallback_impl! { compact_str::CompactString, } + macro_rules! tuple_impl { ($($t:ident),*; $($u:ident),*) => { unsafe impl<$($t),*> Update for ($($t,)*) @@ -451,3 +498,9 @@ where } } } + +unsafe impl Update for PhantomData { + unsafe fn maybe_update(_old_pointer: *mut Self, _new_value: Self) -> bool { + false + } +} diff --git a/tests/compile-fail/tracked_fn_return_ref.rs b/tests/compile-fail/tracked_fn_return_ref.rs index 735739f1e..30708f865 100644 --- a/tests/compile-fail/tracked_fn_return_ref.rs +++ b/tests/compile-fail/tracked_fn_return_ref.rs @@ -1,5 +1,4 @@ use salsa::Database as Db; -use salsa::Update; #[salsa::input] struct MyInput { diff --git a/tests/compile-fail/tracked_fn_return_ref.stderr b/tests/compile-fail/tracked_fn_return_ref.stderr index 95412147e..a9f938a5f 100644 --- a/tests/compile-fail/tracked_fn_return_ref.stderr +++ b/tests/compile-fail/tracked_fn_return_ref.stderr @@ -1,42 +1,20 @@ -warning: unused import: `salsa::Update` - --> tests/compile-fail/tracked_fn_return_ref.rs:2:5 - | -2 | use salsa::Update; - | ^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default - -error[E0277]: the trait bound `&'db str: Update` is not satisfied - --> tests/compile-fail/tracked_fn_return_ref.rs:16:67 - | -16 | fn tracked_fn_return_ref<'db>(db: &'db dyn Db, input: MyInput) -> &'db str { - | ^^^^^^^^ the trait `Update` is not implemented for `&'db str` +error: lifetime may not live long enough + --> tests/compile-fail/tracked_fn_return_ref.rs:14:1 | - = help: the trait `Update` is implemented for `String` -note: required by a bound in `salsa::plumbing::function::Configuration::Output` - --> src/function.rs +14 | #[salsa::tracked] + | ^^^^^^^^^^^^^^^^^ requires that `'db` must outlive `'static` +15 | fn tracked_fn_return_ref<'db>(db: &'db dyn Db, input: MyInput) -> &'db str { + | - lifetime `'db` defined here | - | type Output<'db>: fmt::Debug + Send + Sync + Update; - | ^^^^^^ required by this bound in `Configuration::Output` + = note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) -error[E0277]: the trait bound `ContainsRef<'db>: Update` is not satisfied - --> tests/compile-fail/tracked_fn_return_ref.rs:24:6 - | -24 | ) -> ContainsRef<'db> { - | ^^^^^^^^^^^^^^^^ the trait `Update` is not implemented for `ContainsRef<'db>` +error: lifetime may not live long enough + --> tests/compile-fail/tracked_fn_return_ref.rs:19:1 | - = help: the following other types implement trait `Update`: - () - (A, B) - (A, B, C) - (A, B, C, D) - (A, B, C, D, E) - (A, B, C, D, E, F) - (A, B, C, D, E, F, G) - (A, B, C, D, E, F, G, H) - and $N others -note: required by a bound in `salsa::plumbing::function::Configuration::Output` - --> src/function.rs +19 | #[salsa::tracked] + | ^^^^^^^^^^^^^^^^^ requires that `'db` must outlive `'static` +... +23 | ) -> ContainsRef<'db> { + | ----------- lifetime `'db` defined here | - | type Output<'db>: fmt::Debug + Send + Sync + Update; - | ^^^^^^ required by this bound in `Configuration::Output` + = note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/lru.rs b/tests/lru.rs index 634af7f2f..e8ad930db 100644 --- a/tests/lru.rs +++ b/tests/lru.rs @@ -8,10 +8,10 @@ use std::sync::{ mod common; use common::LogDatabase; -use salsa::{Database as _, Update}; +use salsa::Database as _; use test_log::test; -#[derive(Debug, PartialEq, Eq, Update)] +#[derive(Debug, PartialEq, Eq)] struct HotPotato(u32); thread_local! { diff --git a/tests/tracked_struct.rs b/tests/tracked_struct.rs new file mode 100644 index 000000000..a1bba36c6 --- /dev/null +++ b/tests/tracked_struct.rs @@ -0,0 +1,56 @@ +mod common; + +use salsa::{Database, Setter}; + +#[salsa::tracked] +struct Tracked<'db> { + untracked_1: usize, + + untracked_2: usize, +} + +#[salsa::input] +struct MyInput { + field1: usize, + field2: usize, +} + +#[salsa::tracked] +fn intermediate(db: &dyn salsa::Database, input: MyInput) -> Tracked<'_> { + Tracked::new(db, input.field1(db), input.field2(db)) +} + +#[salsa::tracked] +fn accumulate(db: &dyn salsa::Database, input: MyInput) -> (usize, usize) { + let tracked = intermediate(db, input); + let one = read_tracked_1(db, tracked); + let two = read_tracked_2(db, tracked); + + (one, two) +} + +#[salsa::tracked] +fn read_tracked_1<'db>(db: &'db dyn Database, tracked: Tracked<'db>) -> usize { + tracked.untracked_1(db) +} + +#[salsa::tracked] +fn read_tracked_2<'db>(db: &'db dyn Database, tracked: Tracked<'db>) -> usize { + tracked.untracked_2(db) +} + +#[test] +fn execute() { + let mut db = salsa::DatabaseImpl::default(); + let input = MyInput::new(&db, 1, 1); + + assert_eq!(accumulate(&db, input), (1, 1)); + + // Should only re-execute `read_tracked_1`. + input.set_field1(&mut db).to(2); + assert_eq!(accumulate(&db, input), (2, 1)); + + // Should only re-execute `read_tracked_2`. + input.set_field2(&mut db).to(2); + assert_eq!(accumulate(&db, input), (2, 2)); +} diff --git a/tests/warnings/needless_lifetimes.rs b/tests/warnings/needless_lifetimes.rs index b41ef1f09..0eb9198d0 100644 --- a/tests/warnings/needless_lifetimes.rs +++ b/tests/warnings/needless_lifetimes.rs @@ -1,9 +1,7 @@ -use salsa::Update; - #[salsa::db] pub trait Db: salsa::Database {} -#[derive(Debug, PartialEq, Eq, Hash, Update)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Item {} #[salsa::tracked]