Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ rust-version = "1.76"

[dependencies]
arc-swap = "1"
compact_str = { version = "0.8", optional = true }
crossbeam = "0.8"
dashmap = { version = "6", features = ["raw-api"] }
hashlink = "0.9"
Expand Down
10 changes: 10 additions & 0 deletions components/salsa-macro-rules/src/setup_tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ macro_rules! setup_tracked_fn {
// True if we `return_ref` flag was given to the function
return_ref: $return_ref:tt,

maybe_update_fn: {$($maybe_update_fn:tt)*},

// Annoyingly macro-rules hygiene does not extend to items defined in the macro.
// We have the procedural macro generate names for those items that are
// not used elsewhere in the user's code.
Expand Down Expand Up @@ -145,6 +147,14 @@ macro_rules! setup_tracked_fn {
}
}

/// This method isn't used anywhere. It only exitst to enforce the `Self::Output: Update` constraint
/// for types that aren't `'static`.
///
/// # Safety
/// The same safety rules as for `Update` apply.
$($maybe_update_fn)*


impl $zalsa::function::Configuration for $Configuration {
const DEBUG_NAME: &'static str = stringify!($fn_name);

Expand Down
13 changes: 13 additions & 0 deletions components/salsa-macros/src/tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ impl Macro {

let return_ref: bool = self.args.return_ref.is_some();

let maybe_update_fn = quote_spanned! {output_ty.span()=> {
#[allow(clippy::all, unsafe_code)]
unsafe fn _maybe_update_fn<'db>(old_pointer: *mut #output_ty, new_value: #output_ty) -> bool {
unsafe {
use #zalsa::UpdateFallback;
#zalsa::UpdateDispatch::<#output_ty>::maybe_update(
old_pointer, new_value
)
}
}
}};

Ok(crate::debug::dump_tokens(
fn_name,
quote![salsa::plumbing::setup_tracked_fn! {
Expand All @@ -137,6 +149,7 @@ impl Macro {
needs_interner: #needs_interner,
lru: #lru,
return_ref: #return_ref,
maybe_update_fn: { #maybe_update_fn },
unused_names: [
#zalsa,
#Configuration,
Expand Down
30 changes: 20 additions & 10 deletions components/salsa-macros/src/update.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use proc_macro2::{Literal, TokenStream};
use syn::spanned::Spanned;
use synstructure::BindStyle;

use crate::hygiene::Hygiene;
Expand Down Expand Up @@ -34,7 +35,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
.bindings()
.iter()
.fold(quote!(), |tokens, binding| quote!(#tokens #binding,));
let make_new_value = quote! {
let make_new_value = quote_spanned! {variant.ast().ident.span()=>
let #new_value = if let #variant_pat = #new_value {
(#make_tuple)
} else {
Expand All @@ -46,20 +47,28 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
// For each field, invoke `maybe_update` recursively to update its value.
// Or the results together (using `|`, not `||`, to avoid shortcircuiting)
// to get the final return value.
let update_fields = variant.bindings().iter().zip(0..).fold(
let update_fields = variant.bindings().iter().enumerate().fold(
quote!(false),
|tokens, (binding, index)| {
|tokens, (index, binding)| {
let field_ty = &binding.ast().ty;
let field_index = Literal::usize_unsuffixed(index);

let field_span = binding
.ast()
.ident
.as_ref()
.map(Spanned::span)
.unwrap_or(binding.ast().span());

let update_field = quote_spanned! {field_span=>
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
#binding,
#new_value.#field_index,
)
};

quote! {
#tokens |
unsafe {
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
#binding,
#new_value.#field_index,
)
}
#tokens | unsafe { #update_field }
}
},
);
Expand All @@ -77,6 +86,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let tokens = quote! {
#[allow(clippy::all)]
#[automatically_derived]
unsafe impl #impl_generics salsa::Update for #ident #ty_generics #where_clause {
unsafe fn maybe_update(#old_pointer: *mut Self, #new_value: Self) -> bool {
use ::salsa::plumbing::UpdateFallback as _;
Expand Down
2 changes: 1 addition & 1 deletion examples/calc/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub enum ExpressionData<'db> {
Call(FunctionId<'db>, Vec<Expression<'db>>),
}

#[derive(Eq, PartialEq, Copy, Clone, Hash, Debug, salsa::Update)]
#[derive(Eq, PartialEq, Copy, Clone, Hash, Debug)]
pub enum Op {
Add,
Subtract,
Expand Down
4 changes: 2 additions & 2 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
salsa_struct::SalsaStructInDb,
zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa},
zalsa_local::QueryOrigin,
Cycle, Database, Id, Revision, Update,
Cycle, Database, Id, Revision,
};

use self::delete::DeletedEntries;
Expand Down Expand Up @@ -43,7 +43,7 @@ pub trait Configuration: Any {
type Input<'db>: Send + Sync;

/// The value computed by the function.
type Output<'db>: fmt::Debug + Send + Sync + Update;
type Output<'db>: fmt::Debug + Send + Sync;

/// Determines whether this function can recover from being a participant in a cycle
/// (and, if so, how).
Expand Down
53 changes: 53 additions & 0 deletions src/update.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
hash::{BuildHasher, Hash},
marker::PhantomData,
path::PathBuf,
sync::Arc,
};
Expand Down Expand Up @@ -188,6 +189,29 @@ where
}
}

unsafe impl<A> Update for smallvec::SmallVec<A>
where
A: smallvec::Array,
A::Item: Update,
{
unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool {
let old_vec: &mut smallvec::SmallVec<A> = unsafe { &mut *old_pointer };

if old_vec.len() != new_vec.len() {
old_vec.clear();
old_vec.extend(new_vec);
return true;
}

let mut changed = false;
for (old_element, new_element) in old_vec.iter_mut().zip(new_vec) {
changed |= A::Item::maybe_update(old_element, new_element);
}

changed
}
}

macro_rules! maybe_update_set {
($old_pointer: expr, $new_set: expr) => {{
let old_pointer = $old_pointer;
Expand Down Expand Up @@ -291,6 +315,26 @@ where
}
}

unsafe impl<T> Update for Box<[T]>
where
T: Update,
{
unsafe fn maybe_update(old_pointer: *mut Self, new_box: Self) -> bool {
let old_box: &mut Box<[T]> = unsafe { &mut *old_pointer };

if old_box.len() == new_box.len() {
let mut changed = false;
for (old_element, new_element) in old_box.iter_mut().zip(new_box) {
changed |= T::maybe_update(old_element, new_element);
}
changed
} else {
*old_box = new_box;
true
}
}
}

unsafe impl<T> Update for Arc<T>
where
T: Update,
Expand Down Expand Up @@ -398,6 +442,9 @@ fallback_impl! {
PathBuf,
}

#[cfg(feature = "compact_str")]
fallback_impl! { compact_str::CompactString, }

macro_rules! tuple_impl {
($($t:ident),*; $($u:ident),*) => {
unsafe impl<$($t),*> Update for ($($t,)*)
Expand Down Expand Up @@ -451,3 +498,9 @@ where
}
}
}

unsafe impl<T> Update for PhantomData<T> {
unsafe fn maybe_update(_old_pointer: *mut Self, _new_value: Self) -> bool {
false
}
}
1 change: 0 additions & 1 deletion tests/compile-fail/tracked_fn_return_ref.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use salsa::Database as Db;
use salsa::Update;

#[salsa::input]
struct MyInput {
Expand Down
52 changes: 15 additions & 37 deletions tests/compile-fail/tracked_fn_return_ref.stderr
Original file line number Diff line number Diff line change
@@ -1,42 +1,20 @@
warning: unused import: `salsa::Update`
--> tests/compile-fail/tracked_fn_return_ref.rs:2:5
|
2 | use salsa::Update;
| ^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

error[E0277]: the trait bound `&'db str: Update` is not satisfied
--> tests/compile-fail/tracked_fn_return_ref.rs:16:67
|
16 | fn tracked_fn_return_ref<'db>(db: &'db dyn Db, input: MyInput) -> &'db str {
| ^^^^^^^^ the trait `Update` is not implemented for `&'db str`
error: lifetime may not live long enough
--> tests/compile-fail/tracked_fn_return_ref.rs:14:1
|
= help: the trait `Update` is implemented for `String`
note: required by a bound in `salsa::plumbing::function::Configuration::Output`
--> src/function.rs
14 | #[salsa::tracked]
| ^^^^^^^^^^^^^^^^^ requires that `'db` must outlive `'static`
15 | fn tracked_fn_return_ref<'db>(db: &'db dyn Db, input: MyInput) -> &'db str {
| - lifetime `'db` defined here
|
| type Output<'db>: fmt::Debug + Send + Sync + Update;
| ^^^^^^ required by this bound in `Configuration::Output`
= note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `ContainsRef<'db>: Update` is not satisfied
--> tests/compile-fail/tracked_fn_return_ref.rs:24:6
|
24 | ) -> ContainsRef<'db> {
| ^^^^^^^^^^^^^^^^ the trait `Update` is not implemented for `ContainsRef<'db>`
error: lifetime may not live long enough
--> tests/compile-fail/tracked_fn_return_ref.rs:19:1
|
= help: the following other types implement trait `Update`:
()
(A, B)
(A, B, C)
(A, B, C, D)
(A, B, C, D, E)
(A, B, C, D, E, F)
(A, B, C, D, E, F, G)
(A, B, C, D, E, F, G, H)
and $N others
note: required by a bound in `salsa::plumbing::function::Configuration::Output`
--> src/function.rs
19 | #[salsa::tracked]
| ^^^^^^^^^^^^^^^^^ requires that `'db` must outlive `'static`
...
23 | ) -> ContainsRef<'db> {
| ----------- lifetime `'db` defined here
|
| type Output<'db>: fmt::Debug + Send + Sync + Update;
| ^^^^^^ required by this bound in `Configuration::Output`
= note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info)
4 changes: 2 additions & 2 deletions tests/lru.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use std::sync::{

mod common;
use common::LogDatabase;
use salsa::{Database as _, Update};
use salsa::Database as _;
use test_log::test;

#[derive(Debug, PartialEq, Eq, Update)]
#[derive(Debug, PartialEq, Eq)]
struct HotPotato(u32);

thread_local! {
Expand Down
56 changes: 56 additions & 0 deletions tests/tracked_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
mod common;

use salsa::{Database, Setter};

#[salsa::tracked]
struct Tracked<'db> {
untracked_1: usize,

untracked_2: usize,
}

#[salsa::input]
struct MyInput {
field1: usize,
field2: usize,
}

#[salsa::tracked]
fn intermediate(db: &dyn salsa::Database, input: MyInput) -> Tracked<'_> {
Tracked::new(db, input.field1(db), input.field2(db))
}

#[salsa::tracked]
fn accumulate(db: &dyn salsa::Database, input: MyInput) -> (usize, usize) {
let tracked = intermediate(db, input);
let one = read_tracked_1(db, tracked);
let two = read_tracked_2(db, tracked);

(one, two)
}

#[salsa::tracked]
fn read_tracked_1<'db>(db: &'db dyn Database, tracked: Tracked<'db>) -> usize {
tracked.untracked_1(db)
}

#[salsa::tracked]
fn read_tracked_2<'db>(db: &'db dyn Database, tracked: Tracked<'db>) -> usize {
tracked.untracked_2(db)
}

#[test]
fn execute() {
let mut db = salsa::DatabaseImpl::default();
let input = MyInput::new(&db, 1, 1);

assert_eq!(accumulate(&db, input), (1, 1));

// Should only re-execute `read_tracked_1`.
input.set_field1(&mut db).to(2);
assert_eq!(accumulate(&db, input), (2, 1));

// Should only re-execute `read_tracked_2`.
input.set_field2(&mut db).to(2);
assert_eq!(accumulate(&db, input), (2, 2));
}
4 changes: 1 addition & 3 deletions tests/warnings/needless_lifetimes.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use salsa::Update;

#[salsa::db]
pub trait Db: salsa::Database {}

#[derive(Debug, PartialEq, Eq, Hash, Update)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Item {}

#[salsa::tracked]
Expand Down