diff --git a/src/accumulator.rs b/src/accumulator.rs index a841016dd..9a476c807 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -110,6 +110,7 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Id, _revision: Revision, + _in_cycle: bool, ) -> VerifyResult { panic!("nothing should ever depend on an accumulator directly") } diff --git a/src/active_query.rs b/src/active_query.rs index baca24e69..2535803cc 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -66,6 +66,20 @@ pub(crate) struct ActiveQuery { } impl ActiveQuery { + pub(super) fn seed_iteration( + &mut self, + durability: Durability, + changed_at: Revision, + edges: &[QueryEdge], + untracked_read: bool, + ) { + assert!(self.input_outputs.is_empty()); + self.input_outputs = edges.iter().cloned().collect(); + self.durability = self.durability.min(durability); + self.changed_at = self.changed_at.max(changed_at); + self.untracked_read |= untracked_read; + } + pub(super) fn add_read( &mut self, input: DatabaseKeyIndex, diff --git a/src/function.rs b/src/function.rs index 555bc726c..98662d26d 100644 --- a/src/function.rs +++ b/src/function.rs @@ -232,10 +232,11 @@ where db: &dyn Database, input: Id, revision: Revision, + in_cycle: bool, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient as per caller invariant let db = unsafe { self.view_caster.downcast_unchecked(db) }; - self.maybe_changed_after(db, input, revision) + self.maybe_changed_after(db, input, revision, in_cycle) } /// True if the input `input` contains a memo that cites itself as a cycle head. @@ -285,7 +286,6 @@ where _db: &dyn Database, _executor: DatabaseKeyIndex, _stale_output_key: crate::Id, - _provisional: bool, ) { // This function is invoked when a query Q specifies the value for `stale_output_key` in rev 1, // but not in rev 2. We don't do anything in this case, we just leave the (now stale) memo. diff --git a/src/function/backdate.rs b/src/function/backdate.rs index 685a329ec..fc8d2005a 100644 --- a/src/function/backdate.rs +++ b/src/function/backdate.rs @@ -1,6 +1,7 @@ use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; use crate::zalsa_local::QueryRevisions; +use crate::DatabaseKeyIndex; impl IngredientImpl where @@ -12,6 +13,7 @@ where pub(super) fn backdate_if_appropriate<'db>( &self, old_memo: &Memo>, + index: DatabaseKeyIndex, revisions: &mut QueryRevisions, value: &C::Output<'db>, ) { @@ -24,7 +26,7 @@ where && C::values_equal(old_value, value) { tracing::debug!( - "value is equal, back-dating to {:?}", + "{index:?} value is equal, back-dating to {:?}", old_memo.revisions.changed_at, ); diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index f0d9efaeb..6bc47b5fd 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -24,7 +24,6 @@ where key: DatabaseKeyIndex, old_memo: &Memo>, revisions: &mut QueryRevisions, - provisional: bool, ) { // Iterate over the outputs of the `old_memo` and put them into a hashset let mut old_outputs: FxIndexSet<_> = old_memo.revisions.origin.outputs().collect(); @@ -50,7 +49,7 @@ where }); for old_output in old_outputs { - Self::report_stale_output(zalsa, db, key, old_output, provisional); + Self::report_stale_output(zalsa, db, key, old_output); } } @@ -59,7 +58,6 @@ where db: &C::DbView, key: DatabaseKeyIndex, output: DatabaseKeyIndex, - provisional: bool, ) { db.salsa_event(&|| { Event::new(EventKind::WillDiscardStaleOutput { @@ -67,6 +65,6 @@ where output_key: output, }) }); - output.remove_stale_output(zalsa, db.as_dyn_database(), key, provisional); + output.remove_stale_output(zalsa, db.as_dyn_database(), key); } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 13e695019..bc72634fd 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -101,19 +101,11 @@ where // really change, even if some of its inputs have. So we can // "backdate" its `changed_at` revision to be the same as the // old value. - self.backdate_if_appropriate(old_memo, &mut revisions, &new_value); + self.backdate_if_appropriate(old_memo, database_key_index, &mut revisions, &new_value); // Diff the new outputs with the old, to discard any no-longer-emitted // outputs and update the tracked struct IDs for seeding the next revision. - let provisional = !revisions.cycle_heads.is_empty(); - self.diff_outputs( - zalsa, - db, - database_key_index, - old_memo, - &mut revisions, - provisional, - ); + self.diff_outputs(zalsa, db, database_key_index, old_memo, &mut revisions); } self.insert_memo( zalsa, @@ -142,8 +134,14 @@ where // only when a cycle is actually encountered. let mut opt_last_provisional: Option<&Memo<::Output<'db>>> = None; loop { - let (mut new_value, mut revisions) = - Self::execute_query(db, active_query, opt_old_memo, zalsa.current_revision(), id); + let previous_memo = opt_last_provisional.or(opt_old_memo); + let (mut new_value, mut revisions) = Self::execute_query( + db, + active_query, + previous_memo, + zalsa.current_revision(), + id, + ); // Did the new result we got depend on our own provisional value, in a cycle? if revisions.cycle_heads.contains(&database_key_index) { @@ -255,27 +253,25 @@ where current_revision: Revision, id: Id, ) -> (C::Output<'db>, QueryRevisions) { - // If we already executed this query once, then use the tracked-struct ids from the - // previous execution as the starting point for the new one. if let Some(old_memo) = opt_old_memo { + // If we already executed this query once, then use the tracked-struct ids from the + // previous execution as the starting point for the new one. active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids); + + // Copy over all inputs and outputs from a previous iteration. + // This is necessary to: + // * ensure that tracked struct created during the previous iteration + // (and are owned by the query) are alive even if the query in this iteration no longer creates them. + // * ensure the final returned memo depends on all inputs from all iterations. + if old_memo.may_be_provisional() && old_memo.verified_at.load() == current_revision { + active_query.seed_iteration(&old_memo.revisions); + } } // Query was not previously executed, or value is potentially // stale, or value is absent. Let's execute! let new_value = C::execute(db, C::id_to_input(db, id)); - if let Some(old_memo) = opt_old_memo { - // Copy over all outputs from a previous iteration. - // This is necessary to ensure that tracked struct created during the previous iteration - // (and are owned by the query) are alive even if the query in this iteration no longer creates them. - // The query not re-creating the tracked struct doesn't guarantee that there - // aren't any other queries depending on it. - if old_memo.may_be_provisional() && old_memo.verified_at.load() == current_revision { - active_query.append_outputs(old_memo.revisions.origin.outputs()); - } - } - (new_value, active_query.pop()) } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 4a8d90755..cd1fa804f 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -138,14 +138,10 @@ where "hit cycle at {database_key_index:#?}, \ inserting and returning fixpoint initial value" ); - let revisions = QueryRevisions::fixpoint_initial( - database_key_index, - zalsa.current_revision(), - ); - let initial_value = self.initial_value(db, id).expect( - "`CycleRecoveryStrategy::Fixpoint` \ - should have initial_value", - ); + let revisions = QueryRevisions::fixpoint_initial(database_key_index); + let initial_value = self + .initial_value(db, id) + .expect("`CycleRecoveryStrategy::Fixpoint` should have initial_value"); Some(self.insert_memo( zalsa, id, @@ -159,8 +155,7 @@ where ); let active_query = db.zalsa_local().push_query(database_key_index, 0); let fallback_value = self.initial_value(db, id).expect( - "`CycleRecoveryStrategy::FallbackImmediate` \ - should have initial_value", + "`CycleRecoveryStrategy::FallbackImmediate` should have initial_value", ); let mut revisions = active_query.pop(); revisions.cycle_heads = CycleHeads::initial(database_key_index); diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index b2102ff53..7c513f6df 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -13,7 +13,9 @@ use crate::{AsDynDatabase as _, Id, Revision}; /// Result of memo validation. pub enum VerifyResult { /// Memo has changed and needs to be recomputed. - Changed, + /// + /// The cycle heads encountered when validating the memo. + Changed(CycleHeads), /// Memo remains valid. /// @@ -28,15 +30,37 @@ pub enum VerifyResult { impl VerifyResult { pub(crate) fn changed_if(changed: bool) -> Self { if changed { - Self::Changed + Self::changed() } else { Self::unchanged() } } + pub(crate) fn changed() -> Self { + Self::Changed(CycleHeads::default()) + } + pub(crate) fn unchanged() -> Self { Self::Unchanged(InputAccumulatedValues::Empty, CycleHeads::default()) } + + pub(crate) fn cycle_heads(&self) -> &CycleHeads { + match self { + Self::Changed(cycle_heads) => cycle_heads, + Self::Unchanged(_, cycle_heads) => cycle_heads, + } + } + + pub(crate) fn into_cycle_heads(self) -> CycleHeads { + match self { + Self::Changed(cycle_heads) => cycle_heads, + Self::Unchanged(_, cycle_heads) => cycle_heads, + } + } + + pub(crate) const fn is_unchanged(&self) -> bool { + matches!(self, Self::Unchanged(_, _)) + } } impl IngredientImpl @@ -48,6 +72,7 @@ where db: &'db C::DbView, id: Id, revision: Revision, + in_cycle: bool, ) -> VerifyResult { let zalsa = db.zalsa(); let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); @@ -62,7 +87,7 @@ where let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); let Some(memo) = memo_guard else { // No memo? Assume has changed. - return VerifyResult::Changed; + return VerifyResult::changed(); }; if let Some(shallow_update) = self.shallow_verify_memo(zalsa, database_key_index, memo) @@ -71,7 +96,7 @@ where self.update_shallow(db, zalsa, database_key_index, memo, shallow_update); return if memo.revisions.changed_at > revision { - VerifyResult::Changed + VerifyResult::changed() } else { VerifyResult::Unchanged( memo.revisions.accumulated_inputs.load(), @@ -81,9 +106,14 @@ where } } - if let Some(mcs) = - self.maybe_changed_after_cold(zalsa, db, id, revision, memo_ingredient_index) - { + if let Some(mcs) = self.maybe_changed_after_cold( + zalsa, + db, + id, + revision, + memo_ingredient_index, + in_cycle, + ) { return mcs; } else { // We failed to claim, have to retry. @@ -99,6 +129,7 @@ where key_index: Id, revision: Revision, memo_ingredient_index: MemoIngredientIndex, + in_cycle: bool, ) -> Option { let database_key_index = self.database_key_index(key_index); @@ -116,6 +147,9 @@ where return Some(VerifyResult::unchanged()); } CycleRecoveryStrategy::Fixpoint => { + tracing::debug!( + "hit cycle at {database_key_index:?} in `maybe_changed_after`, returning fixpoint initial value", + ); return Some(VerifyResult::Unchanged( InputAccumulatedValues::Empty, CycleHeads::initial(database_key_index), @@ -127,7 +161,7 @@ where // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index) else { - return Some(VerifyResult::Changed); + return Some(VerifyResult::changed()); }; tracing::debug!( @@ -137,13 +171,15 @@ where ); // Check if the inputs are still valid. We can just compare `changed_at`. - if let VerifyResult::Unchanged(_, cycle_heads) = - self.deep_verify_memo(db, zalsa, old_memo, database_key_index) - { + let deep_verify = self.deep_verify_memo(db, zalsa, old_memo, database_key_index); + if deep_verify.is_unchanged() { return Some(if old_memo.revisions.changed_at > revision { - VerifyResult::Changed + VerifyResult::Changed(deep_verify.into_cycle_heads()) } else { - VerifyResult::Unchanged(old_memo.revisions.accumulated_inputs.load(), cycle_heads) + VerifyResult::Unchanged( + old_memo.revisions.accumulated_inputs.load(), + deep_verify.into_cycle_heads(), + ) }); } @@ -151,13 +187,19 @@ where // It is possible the result will be equal to the old value and hence // backdated. In that case, although we will have computed a new memo, // the value has not logically changed. - if old_memo.value.is_some() { + // However, executing the query here is only safe if we are not in a cycle. + // In a cycle, it's important that the cycle head gets executed or we + // risk that some dependencies of this query haven't been verified yet because + // the cycle head returned *fixpoint initial* without validating its dependencies. + // `in_cycle` tracks if the enclosing query is in a cycle. `deep_verify.cycle_heads` tracks + // if **this query** encountered a cycle (which means there's some provisional value somewhere floating around). + if old_memo.value.is_some() && !in_cycle && deep_verify.cycle_heads().is_empty() { let active_query = db.zalsa_local().push_query(database_key_index, 0); let memo = self.execute(db, active_query, Some(old_memo)); let changed_at = memo.revisions.changed_at; return Some(if changed_at > revision { - VerifyResult::Changed + VerifyResult::changed() } else { VerifyResult::Unchanged( match &memo.revisions.accumulated { @@ -170,7 +212,7 @@ where } // Otherwise, nothing for it: have to consider the value to have changed. - Some(VerifyResult::Changed) + Some(VerifyResult::Changed(deep_verify.into_cycle_heads())) } /// `Some` if the memo's value and `changed_at` time is still valid in this revision. @@ -369,20 +411,33 @@ where // Conditionally specified queries // where the value is specified // in rev 1 but not in rev 2. - VerifyResult::Changed + VerifyResult::changed() + } + QueryOrigin::FixpointInitial => { + let is_provisional = old_memo.may_be_provisional(); + + // If the value is from the same revision but is still provisional, consider it changed + // because we're now in a new iteration. + if shallow_update_possible && is_provisional { + return VerifyResult::Changed(CycleHeads::initial(database_key_index)); + } + + VerifyResult::Unchanged( + InputAccumulatedValues::Empty, + CycleHeads::initial(database_key_index), + ) } - QueryOrigin::FixpointInitial if old_memo.may_be_provisional() => VerifyResult::Changed, - QueryOrigin::FixpointInitial => VerifyResult::unchanged(), QueryOrigin::DerivedUntracked(_) => { // Untracked inputs? Have to assume that it changed. - VerifyResult::Changed + VerifyResult::changed() } QueryOrigin::Derived(edges) => { let is_provisional = old_memo.may_be_provisional(); // If the value is from the same revision but is still provisional, consider it changed + // because we're now in a new iteration. if shallow_update_possible && is_provisional { - return VerifyResult::Changed; + return VerifyResult::changed(); } let mut cycle_heads = CycleHeads::default(); @@ -399,9 +454,14 @@ where for &edge in edges.input_outputs.iter() { match edge { QueryEdge::Input(dependency_index) => { - match dependency_index.maybe_changed_after(dyn_db, last_verified_at) - { - VerifyResult::Changed => break 'cycle VerifyResult::Changed, + match dependency_index.maybe_changed_after( + dyn_db, + last_verified_at, + !cycle_heads.is_empty(), + ) { + VerifyResult::Changed(_) => { + break 'cycle VerifyResult::Changed(cycle_heads) + } VerifyResult::Unchanged(input_accumulated, cycles) => { cycle_heads.extend(&cycles); inputs |= input_accumulated; diff --git a/src/function/specify.rs b/src/function/specify.rs index 4d5c7169a..f87b01acb 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -76,15 +76,8 @@ where let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) { - self.backdate_if_appropriate(old_memo, &mut revisions, &value); - self.diff_outputs( - zalsa, - db, - database_key_index, - old_memo, - &mut revisions, - false, - ); + self.backdate_if_appropriate(old_memo, database_key_index, &mut revisions, &value); + self.diff_outputs(zalsa, db, database_key_index, old_memo, &mut revisions); } let memo = Memo { diff --git a/src/ingredient.rs b/src/ingredient.rs index c50a5f57f..756af3b0e 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -63,6 +63,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { db: &'db dyn Database, input: Id, revision: Revision, + in_cycle: bool, ) -> VerifyResult; /// Is the value for `input` in this ingredient a cycle head that is still provisional? @@ -106,9 +107,8 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { db: &dyn Database, executor: DatabaseKeyIndex, stale_output_key: Id, - provisional: bool, ) { - let _ = (db, executor, stale_output_key, provisional); + let _ = (db, executor, stale_output_key); unreachable!("only tracked struct ingredients can have stale outputs") } diff --git a/src/input.rs b/src/input.rs index 4aab953fb..ef1c1ca6d 100644 --- a/src/input.rs +++ b/src/input.rs @@ -213,6 +213,7 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Id, _revision: Revision, + _in_cycle: bool, ) -> VerifyResult { // Input ingredients are just a counter, they store no data, they are immortal. // Their *fields* are stored in function ingredients elsewhere. diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 0fa94a100..85d8bdf9b 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -54,6 +54,7 @@ where db: &dyn Database, input: Id, revision: Revision, + _in_cycle: bool, ) -> VerifyResult { let zalsa = db.zalsa(); let value = >::data(zalsa, input); diff --git a/src/interned.rs b/src/interned.rs index 3fdacc7bc..07716a055 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -350,7 +350,8 @@ where assert!( internal_data.last_interned_at.load() >= last_changed_revision, - "Data was not interned in the latest revision for its durability." + "Data {:?} was not interned in the latest revision for its durability.", + self.database_key_index(id) ); unsafe { Self::from_internal_data(&internal_data.fields) } @@ -394,12 +395,13 @@ where db: &dyn Database, input: Id, revision: Revision, + _in_cycle: bool, ) -> VerifyResult { let zalsa = db.zalsa(); let value = zalsa.table().get::>(input); if value.first_interned_at > revision { // The slot was reused. - return VerifyResult::Changed; + return VerifyResult::changed(); } // The slot is valid in this revision but we have to sync the value's revision. diff --git a/src/key.rs b/src/key.rs index d2e8ed1f1..c3fa57ada 100644 --- a/src/key.rs +++ b/src/key.rs @@ -37,12 +37,13 @@ impl DatabaseKeyIndex { &self, db: &dyn Database, last_verified_at: crate::Revision, + in_cycle: bool, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient unsafe { db.zalsa() .lookup_ingredient(self.ingredient_index) - .maybe_changed_after(db, self.key_index, last_verified_at) + .maybe_changed_after(db, self.key_index, last_verified_at, in_cycle) } } @@ -51,11 +52,10 @@ impl DatabaseKeyIndex { zalsa: &Zalsa, db: &dyn Database, executor: DatabaseKeyIndex, - provisional: bool, ) { zalsa .lookup_ingredient(self.ingredient_index) - .remove_stale_output(db, executor, self.key_index, provisional) + .remove_stale_output(db, executor, self.key_index) } pub(crate) fn mark_validated_output( diff --git a/src/runtime.rs b/src/runtime.rs index 36587428b..2ea77133c 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -126,6 +126,7 @@ impl Runtime { } pub(crate) fn set_cancellation_flag(&self) { + tracing::trace!("set_cancellation_flag"); self.revision_canceled.store(true, Ordering::Release); } @@ -151,6 +152,7 @@ impl Runtime { let r_old = self.current_revision(); let r_new = r_old.next(); self.revisions[0] = r_new; + tracing::debug!("new_revision: {r_old:?} -> {r_new:?}"); r_new } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 1b09c1087..03cd17d67 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -405,7 +405,7 @@ where // This is a new tracked struct, so create an entry in the struct map. let id = self.allocate(zalsa, zalsa_local, current_revision, ¤t_deps, fields); let key = self.database_key_index(id); - tracing::trace!("Allocated new tracked struct {id:?}", id = key); + tracing::trace!("Allocated new tracked struct {key:?}"); zalsa_local.add_output(key); zalsa_local.store_tracked_struct_id(identity, id); FromId::from_id(id) @@ -581,7 +581,7 @@ where /// Using this method on an entity id that MAY be used in the current revision will lead to /// unspecified results (but not UB). See [`InternedIngredient::delete_index`] for more /// discussion and important considerations. - pub(crate) fn delete_entity(&self, db: &dyn crate::Database, id: Id, provisional: bool) { + pub(crate) fn delete_entity(&self, db: &dyn crate::Database, id: Id) { db.salsa_event(&|| { Event::new(crate::EventKind::DidDiscard { key: self.database_key_index(id), @@ -599,7 +599,7 @@ where None => { panic!("cannot delete write-locked id `{id:?}`; value leaked across threads"); } - Some(r) if !provisional && r == current_revision => panic!( + Some(r) if r == current_revision => panic!( "cannot delete read-locked id `{id:?}`; value leaked across threads or user functions not deterministic" ), Some(r) => { @@ -637,7 +637,7 @@ where db.salsa_event(&|| Event::new(EventKind::DidDiscard { key: executor })); for stale_output in memo.origin().outputs() { - stale_output.remove_stale_output(zalsa, db, executor, provisional); + stale_output.remove_stale_output(zalsa, db, executor); } }) }; @@ -739,6 +739,7 @@ where db: &dyn Database, input: Id, revision: Revision, + _in_cycle: bool, ) -> VerifyResult { let zalsa = db.zalsa(); let data = Self::data(zalsa.table(), input); @@ -766,13 +767,12 @@ where db: &dyn Database, _executor: DatabaseKeyIndex, stale_output_key: crate::Id, - provisional: bool, ) { // This method is called when, in prior revisions, // `executor` creates a tracked struct `salsa_output_key`, // but it did not in the current revision. // In that case, we can delete `stale_output_key` and any data associated with it. - self.delete_entity(db, stale_output_key, provisional); + self.delete_entity(db, stale_output_key); } fn debug_name(&self) -> &'static str { diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 04619469b..1f6a79bef 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -58,6 +58,7 @@ where db: &'db dyn Database, input: Id, revision: crate::Revision, + _in_cycle: bool, ) -> VerifyResult { let zalsa = db.zalsa(); let data = >::data(zalsa.table(), input); diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 322c771c6..1cf239788 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -370,9 +370,9 @@ pub(crate) struct QueryRevisions { } impl QueryRevisions { - pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, revision: Revision) -> Self { + pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex) -> Self { Self { - changed_at: revision, + changed_at: Revision::start(), durability: Durability::MAX, origin: QueryOrigin::FixpointInitial, tracked_struct_ids: Default::default(), @@ -424,6 +424,16 @@ impl QueryOrigin { }; opt_edges.into_iter().flat_map(|edges| edges.outputs()) } + + pub(crate) fn edges(&self) -> &[QueryEdge] { + let opt_edges = match self { + QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), + QueryOrigin::Assigned(_) | QueryOrigin::FixpointInitial => None, + }; + opt_edges + .map(|edges| &*edges.input_outputs) + .unwrap_or_default() + } } /// The edges between a memoized value and other queries in the dependency graph. @@ -508,18 +518,17 @@ impl ActiveQueryGuard<'_> { } /// Append the given `outputs` to the query's output list. - pub(crate) fn append_outputs(&self, outputs: I) - where - I: IntoIterator + UnwindSafe, - { + pub(crate) fn seed_iteration(&self, previous: &QueryRevisions) { + let durability = previous.durability; + let changed_at = previous.changed_at; + let edges = previous.origin.edges(); + let untracked_read = matches!(previous.origin, QueryOrigin::DerivedUntracked(_)); + self.local_state.with_query_stack_mut(|stack| { #[cfg(debug_assertions)] assert_eq!(stack.len(), self.push_len); let frame = stack.last_mut().unwrap(); - - for output in outputs { - frame.add_output(output); - } + frame.seed_iteration(durability, changed_at, edges, untracked_read); }) } diff --git a/tests/cycle_maybe_changed_after.rs b/tests/cycle_maybe_changed_after.rs new file mode 100644 index 000000000..cfc271bbc --- /dev/null +++ b/tests/cycle_maybe_changed_after.rs @@ -0,0 +1,213 @@ +//! Tests for incremental validation for queries involved in a cycle. +mod common; + +use crate::common::EventLoggerDatabase; +use salsa::{CycleRecoveryAction, Database, Durability, Setter}; + +#[salsa::input(debug)] +struct Input { + value: u32, + max: u32, +} + +#[salsa::interned(debug)] +struct Output<'db> { + #[return_ref] + value: u32, +} + +#[salsa::tracked(cycle_fn=query_a_recover, cycle_initial=query_a_initial)] +fn query_c<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 { + query_d(db, input) +} + +#[salsa::tracked] +fn query_d<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 { + let value = query_c(db, input); + if value < input.max(db) * 2 { + // Only the first iteration depends on value but the entire + // cycle must re-run if input changes. + let result = value + input.value(db); + Output::new(db, result); + result + } else { + value + } +} + +fn query_a_initial(_db: &dyn Database, _input: Input) -> u32 { + 0 +} + +fn query_a_recover( + _db: &dyn Database, + _output: &u32, + _count: u32, + _input: Input, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +/// Only the first iteration depends on `input.value`. It's important that the entire query +/// reruns if `input.value` changes. That's why salsa has to carry-over the inputs and outputs +/// from the previous iteration. +#[test_log::test] +fn first_iteration_input_only() { + #[salsa::tracked(cycle_fn=query_a_recover, cycle_initial=query_a_initial)] + fn query_a<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 { + query_b(db, input) + } + + #[salsa::tracked] + fn query_b<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 { + let value = query_a(db, input); + + if value < input.max(db) { + // Only the first iteration depends on value but the entire + // cycle must re-run if input changes. + value + input.value(db) + } else { + value + } + } + + let mut db = EventLoggerDatabase::default(); + + let input = Input::builder(4, 5).durability(Durability::MEDIUM).new(&db); + + { + let result = query_a(&db, input); + + assert_eq!(result, 8); + } + + { + input.set_value(&mut db).to(3); + + let result = query_a(&db, input); + assert_eq!(result, 6); + } +} + +/// Very similar to the previous test, but the difference is that the called function +/// isn't the cycle head and that `cycle_participant` is called from +/// both the `cycle_head` and the `entry` function. +#[test_log::test] +fn nested_cycle_fewer_dependencies_in_first_iteration() { + #[salsa::interned(debug)] + struct ClassLiteral<'db> { + scope: Scope<'db>, + } + + #[salsa::tracked] + impl<'db> ClassLiteral<'db> { + #[salsa::tracked] + fn context(self, db: &'db dyn salsa::Database) -> u32 { + let scope = self.scope(db); + + // Access a field on `scope` that changed in the new revision. + scope.field(db) + } + } + + #[salsa::tracked(debug)] + struct Scope<'db> { + field: u32, + } + + #[salsa::tracked] + fn create_interned<'db>(db: &'db dyn salsa::Database, scope: Scope<'db>) -> ClassLiteral<'db> { + ClassLiteral::new(db, scope) + } + + #[derive(Eq, PartialEq, Debug, salsa::Update)] + struct Index<'db> { + scope: Scope<'db>, + } + + #[salsa::tracked(cycle_fn=head_recover, cycle_initial=head_initial)] + fn cycle_head<'db>(db: &'db dyn salsa::Database, input: Input) -> Option> { + let b = cycle_outer(db, input); + tracing::info!("query_b = {b:?}"); + + b.or_else(|| { + let index = index(db, input); + Some(create_interned(db, index.scope)) + }) + } + + fn head_initial(_db: &dyn Database, _input: Input) -> Option> { + None + } + + fn head_recover<'db>( + _db: &'db dyn Database, + _output: &Option>, + _count: u32, + _input: Input, + ) -> CycleRecoveryAction>> { + CycleRecoveryAction::Iterate + } + + #[salsa::tracked] + fn cycle_outer<'db>(db: &'db dyn salsa::Database, input: Input) -> Option> { + cycle_participant(db, input) + } + + #[salsa::tracked] + fn cycle_participant<'db>( + db: &'db dyn salsa::Database, + input: Input, + ) -> Option> { + let value = cycle_head(db, input); + tracing::info!("cycle_head = {value:?}"); + + if let Some(value) = value { + value.context(db); + Some(value) + } else { + None + } + } + + #[salsa::tracked(return_ref)] + fn index<'db>(db: &'db dyn salsa::Database, input: Input) -> Index<'db> { + Index { + scope: Scope::new(db, input.value(db) * 2), + } + } + + #[salsa::tracked] + fn entry(db: &dyn salsa::Database, input: Input) -> u32 { + let _ = input.value(db); + let head = cycle_head(db, input); + + let participant = cycle_participant(db, input); + tracing::debug!("head: {head:?}, participant: {participant:?}"); + + head.or(participant) + .map(|class| class.scope(db).field(db)) + .unwrap_or(0) + } + + let mut db = EventLoggerDatabase::default(); + + let input = Input::builder(3, 5) + .max_durability(Durability::HIGH) + .value_durability(Durability::LOW) + .new(&db); + + { + let result = entry(&db, input); + + assert_eq!(result, 6); + } + + db.synthetic_write(Durability::MEDIUM); + + { + input.set_value(&mut db).to(4); + let result = entry(&db, input); + assert_eq!(result, 8); + } +} diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index 66b9d566c..7c5d77ddc 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -140,9 +140,12 @@ fn revalidate_no_changes() { db.assert_logs(expect![[r#" [ "salsa_event(DidSetCancellationFlag)", - "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(403)) })", + "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(400)) })", "salsa_event(DidReinternValue { key: query_d::interned_arguments(Id(800)), revision: R2 })", "salsa_event(DidValidateMemoizedValue { database_key: query_d(Id(800)) })", + "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(401)) })", + "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(402)) })", + "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(403)) })", "salsa_event(DidValidateMemoizedValue { database_key: query_b(Id(0)) })", "salsa_event(DidReinternValue { key: query_d::interned_arguments(Id(800)), revision: R2 })", "salsa_event(DidValidateMemoizedValue { database_key: query_a(Id(0)) })", @@ -170,29 +173,29 @@ fn revalidate_with_change_after_output_read() { db.assert_logs(expect![[r#" [ "salsa_event(DidSetCancellationFlag)", - "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(403)) })", + "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(400)) })", + "salsa_event(DidReinternValue { key: query_d::interned_arguments(Id(800)), revision: R2 })", + "salsa_event(WillExecute { database_key: query_b(Id(0)) })", "salsa_event(DidReinternValue { key: query_d::interned_arguments(Id(800)), revision: R2 })", - "salsa_event(WillExecute { database_key: query_d(Id(800)) })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", - "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(400)) })", + "salsa_event(WillExecute { database_key: query_d(Id(800)) })", "salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(403)) })", "salsa_event(DidDiscard { key: Output(Id(403)) })", "salsa_event(DidDiscard { key: read_value(Id(403)) })", - "salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(402)) })", - "salsa_event(DidDiscard { key: Output(Id(402)) })", - "salsa_event(DidDiscard { key: read_value(Id(402)) })", "salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(401)) })", "salsa_event(DidDiscard { key: Output(Id(401)) })", "salsa_event(DidDiscard { key: read_value(Id(401)) })", - "salsa_event(WillExecute { database_key: query_b(Id(0)) })", + "salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(402)) })", + "salsa_event(DidDiscard { key: Output(Id(402)) })", + "salsa_event(DidDiscard { key: read_value(Id(402)) })", "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: 1, fell_back: false })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: read_value(Id(403)) })", "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: 2, fell_back: false })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", - "salsa_event(WillExecute { database_key: read_value(Id(402)) })", + "salsa_event(WillExecute { database_key: read_value(Id(401)) })", "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: 3, fell_back: false })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", - "salsa_event(WillExecute { database_key: read_value(Id(401)) })", + "salsa_event(WillExecute { database_key: read_value(Id(402)) })", ]"#]]); }