diff --git a/src/active_query.rs b/src/active_query.rs index f9b0eb9ed..71ec0bbd6 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -60,6 +60,9 @@ pub(crate) struct ActiveQuery { /// Provisional cycle results that this query depends on. cycle_heads: CycleHeads, + + /// If this query is a cycle head, iteration count of that cycle. + iteration_count: u32, } impl ActiveQuery { @@ -126,10 +129,14 @@ impl ActiveQuery { changed_at: self.changed_at, } } + + pub(super) fn iteration_count(&self) -> u32 { + self.iteration_count + } } impl ActiveQuery { - fn new(database_key_index: DatabaseKeyIndex) -> Self { + fn new(database_key_index: DatabaseKeyIndex, iteration_count: u32) -> Self { ActiveQuery { database_key_index, durability: Durability::MAX, @@ -141,6 +148,7 @@ impl ActiveQuery { accumulated: Default::default(), accumulated_inputs: Default::default(), cycle_heads: Default::default(), + iteration_count, } } @@ -156,6 +164,7 @@ impl ActiveQuery { ref mut accumulated, accumulated_inputs, ref mut cycle_heads, + iteration_count: _, } = self; let edges = QueryEdges::new(input_outputs.drain(..)); @@ -196,15 +205,17 @@ impl ActiveQuery { accumulated, accumulated_inputs: _, cycle_heads, + iteration_count, } = self; input_outputs.clear(); disambiguator_map.clear(); tracked_struct_ids.clear(); accumulated.clear(); *cycle_heads = Default::default(); + *iteration_count = 0; } - fn reset_for(&mut self, new_database_key_index: DatabaseKeyIndex) { + fn reset_for(&mut self, new_database_key_index: DatabaseKeyIndex, new_iteration_count: u32) { let Self { database_key_index, durability, @@ -216,12 +227,14 @@ impl ActiveQuery { accumulated, accumulated_inputs, cycle_heads, + iteration_count, } = self; *database_key_index = new_database_key_index; *durability = Durability::MAX; *changed_at = Revision::start(); *untracked_read = false; *accumulated_inputs = Default::default(); + *iteration_count = new_iteration_count; debug_assert!( input_outputs.is_empty(), 
"`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called" @@ -266,11 +279,16 @@ impl ops::DerefMut for QueryStack { } impl QueryStack { - pub(crate) fn push_new_query(&mut self, database_key_index: DatabaseKeyIndex) { + pub(crate) fn push_new_query( + &mut self, + database_key_index: DatabaseKeyIndex, + iteration_count: u32, + ) { if self.len < self.stack.len() { - self.stack[self.len].reset_for(database_key_index); + self.stack[self.len].reset_for(database_key_index, iteration_count); } else { - self.stack.push(ActiveQuery::new(database_key_index)); + self.stack + .push(ActiveQuery::new(database_key_index, iteration_count)); } self.len += 1; } diff --git a/src/cycle.rs b/src/cycle.rs index f8634db30..cd9fe3c9a 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -86,12 +86,19 @@ pub enum CycleRecoveryStrategy { /// A "cycle head" is the query at which we encounter a cycle; that is, if A -> B -> C -> A, then A /// would be the cycle head. It returns an "initial value" when the cycle is encountered (if /// fixpoint iteration is enabled for that query), and then is responsible for re-iterating the -/// cycle until it converges. Any provisional value generated by any query in the cycle will track -/// the cycle head(s) (can be plural in case of nested cycles) representing the cycles it is part -/// of. This struct tracks these cycle heads. +/// cycle until it converges. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct CycleHead { + pub database_key_index: DatabaseKeyIndex, + pub iteration_count: u32, +} + +/// Any provisional value generated by any query in a cycle will track the cycle head(s) (can be +/// plural in case of nested cycles) representing the cycles it is part of, and the current +/// iteration count for each cycle head. This struct tracks these cycle heads. 
#[derive(Clone, Debug, Default)] #[allow(clippy::box_collection)] -pub struct CycleHeads(Option<Box<Vec<DatabaseKeyIndex>>>); +pub struct CycleHeads(Option<Box<Vec<CycleHead>>>); impl CycleHeads { pub(crate) fn is_empty(&self) -> bool { @@ -100,15 +107,25 @@ impl CycleHeads { self.0.is_none() } + pub(crate) fn initial(database_key_index: DatabaseKeyIndex) -> Self { + Self(Some(Box::new(vec![CycleHead { + database_key_index, + iteration_count: 0, + }]))) + } + pub(crate) fn contains(&self, value: &DatabaseKeyIndex) -> bool { - self.0.as_ref().is_some_and(|heads| heads.contains(value)) + self.into_iter() + .any(|head| head.database_key_index == *value) } pub(crate) fn remove(&mut self, value: &DatabaseKeyIndex) -> bool { let Some(cycle_heads) = &mut self.0 else { return false; }; - let found = cycle_heads.iter().position(|&head| head == *value); + let found = cycle_heads + .iter() + .position(|&head| head.database_key_index == *value); let Some(found) = found else { return false }; cycle_heads.swap_remove(found); if cycle_heads.is_empty() { @@ -117,32 +134,52 @@ impl CycleHeads { true } + pub(crate) fn update_iteration_count( + &mut self, + cycle_head_index: DatabaseKeyIndex, + new_iteration_count: u32, + ) { + if let Some(cycle_head) = self.0.as_mut().and_then(|cycle_heads| { + cycle_heads + .iter_mut() + .find(|cycle_head| cycle_head.database_key_index == cycle_head_index) + }) { + cycle_head.iteration_count = new_iteration_count; + } + } + #[inline] - pub(crate) fn insert_into(self, cycle_heads: &mut Vec<DatabaseKeyIndex>) { + pub(crate) fn insert_into(self, cycle_heads: &mut Vec<CycleHead>) { if let Some(heads) = self.0 { - for head in *heads { - if !cycle_heads.contains(&head) { - cycle_heads.push(head); - } - } + insert_into_impl(&heads, cycle_heads); } } pub(crate) fn extend(&mut self, other: &Self) { if let Some(other) = &other.0 { let heads = &mut **self.0.get_or_insert_with(|| Box::new(Vec::new())); - heads.reserve(other.len()); - other.iter().for_each(|&head| { - if !heads.contains(&head) { - heads.push(head); - } - }); + 
insert_into_impl(other, heads); + } + } +} + +#[inline] +fn insert_into_impl(insert_from: &Vec<CycleHead>, insert_into: &mut Vec<CycleHead>) { + insert_into.reserve(insert_from.len()); + for head in insert_from { + if let Some(existing) = insert_into + .iter() + .find(|candidate| candidate.database_key_index == head.database_key_index) + { + assert!(existing.iteration_count == head.iteration_count); + } else { + insert_into.push(*head); } } } impl IntoIterator for CycleHeads { - type Item = DatabaseKeyIndex; + type Item = CycleHead; type IntoIter = <Vec<Self::Item> as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { @@ -150,10 +187,10 @@ impl IntoIterator for CycleHeads { } } -pub struct CycleHeadsIter<'a>(std::slice::Iter<'a, DatabaseKeyIndex>); +pub struct CycleHeadsIter<'a>(std::slice::Iter<'a, CycleHead>); impl Iterator for CycleHeadsIter<'_> { - type Item = DatabaseKeyIndex; + type Item = CycleHead; fn next(&mut self) -> Option<Self::Item> { self.0.next().copied() @@ -167,7 +204,7 @@ impl Iterator for CycleHeadsIter<'_> { impl std::iter::FusedIterator for CycleHeadsIter<'_> {} impl<'a> std::iter::IntoIterator for &'a CycleHeads { - type Item = DatabaseKeyIndex; + type Item = CycleHead; type IntoIter = CycleHeadsIter<'a>; fn into_iter(self) -> Self::IntoIter { @@ -180,14 +217,14 @@ impl<'a> std::iter::IntoIterator for &'a CycleHeads { } } -impl From<DatabaseKeyIndex> for CycleHeads { - fn from(value: DatabaseKeyIndex) -> Self { +impl From<CycleHead> for CycleHeads { + fn from(value: CycleHead) -> Self { Self(Some(Box::new(vec![value]))) } } -impl From<Vec<DatabaseKeyIndex>> for CycleHeads { - fn from(value: Vec<DatabaseKeyIndex>) -> Self { +impl From<Vec<CycleHead>> for CycleHeads { + fn from(value: Vec<CycleHead>) -> Self { Self(if value.is_empty() { None } else { diff --git a/src/function.rs b/src/function.rs index c989de0ef..6227257c0 100644 --- a/src/function.rs +++ b/src/function.rs @@ -243,7 +243,11 @@ where fn is_provisional_cycle_head<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool { let zalsa = db.zalsa(); self.get_memo_from_table_for(zalsa, input, 
self.memo_ingredient_index(zalsa, input)) - .is_some_and(|memo| memo.cycle_heads().contains(&self.database_key_index(input))) + .is_some_and(|memo| { + memo.cycle_heads() + .into_iter() + .any(|head| head.database_key_index == self.database_key_index(input)) + }) } /// Attempts to claim `key_index`, returning `false` if a cycle occurs. diff --git a/src/function/execute.rs b/src/function/execute.rs index d13839d21..c8c2e3237 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -123,6 +123,9 @@ where if iteration_count > MAX_ITERATIONS { panic!("{database_key_index:?}: execute: too many cycle iterations"); } + revisions + .cycle_heads + .update_iteration_count(database_key_index, iteration_count); opt_last_provisional = Some(self.insert_memo( zalsa, id, @@ -130,7 +133,9 @@ where memo_ingredient_index, )); - active_query = db.zalsa_local().push_query(database_key_index); + active_query = db + .zalsa_local() + .push_query(database_key_index, iteration_count); continue; } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index edc12d3d3..6df5b0372 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -83,7 +83,8 @@ where if let Some(memo) = memo_guard { let database_key_index = self.database_key_index(id); if memo.value.is_some() - && self.validate_may_be_provisional(db, zalsa, database_key_index, memo) + && (self.validate_may_be_provisional(db, zalsa, database_key_index, memo) + || self.validate_same_iteration(db, database_key_index, memo)) && self.shallow_verify_memo(db, zalsa, database_key_index, memo) { // SAFETY: memo is present in memo_map and we have verified that it is @@ -158,7 +159,7 @@ where }; // Push the query on the stack. - let active_query = db.zalsa_local().push_query(database_key_index); + let active_query = db.zalsa_local().push_query(database_key_index, 0); // Now that we've claimed the item, check again to see if there's a "hot" value. 
let opt_old_memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index ae4dadab8..d0b89ff0c 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -112,7 +112,7 @@ where CycleRecoveryStrategy::Fixpoint => { return Some(VerifyResult::Unchanged( InputAccumulatedValues::Empty, - CycleHeads::from(database_key_index), + CycleHeads::initial(database_key_index), )); } }, @@ -131,7 +131,7 @@ where ); // Check if the inputs are still valid. We can just compare `changed_at`. - let active_query = db.zalsa_local().push_query(database_key_index); + let active_query = db.zalsa_local().push_query(database_key_index, 0); if let VerifyResult::Unchanged(_, cycle_heads) = self.deep_verify_memo(db, zalsa, old_memo, &active_query) { @@ -243,14 +243,17 @@ where database_key_index: DatabaseKeyIndex, memo: &Memo<C::Output<'db>>, ) -> bool { - tracing::debug!( + tracing::trace!( "{database_key_index:?}: validate_provisional(memo = {memo:#?})", memo = memo.tracing_debug() ); if (&memo.revisions.cycle_heads).into_iter().any(|cycle_head| { zalsa - .lookup_ingredient(cycle_head.ingredient_index()) - .is_provisional_cycle_head(db.as_dyn_database(), cycle_head.key_index()) + .lookup_ingredient(cycle_head.database_key_index.ingredient_index()) + .is_provisional_cycle_head( + db.as_dyn_database(), + cycle_head.database_key_index.key_index(), + ) }) { return false; } @@ -260,6 +263,33 @@ where true } + /// If this is a provisional memo, validate that it was cached in the same iteration of the + /// same cycle(s) that we are still executing. If so, it is valid for reuse. This avoids + /// runaway re-execution of the same queries within a fixpoint iteration. 
+ pub(super) fn validate_same_iteration( + &self, + db: &C::DbView, + database_key_index: DatabaseKeyIndex, + memo: &Memo<C::Output<'_>>, + ) -> bool { + tracing::trace!( + "{database_key_index:?}: validate_same_iteration(memo = {memo:#?})", + memo = memo.tracing_debug() + ); + for cycle_head in &memo.revisions.cycle_heads { + if !db.zalsa_local().with_query_stack(|stack| { + stack.iter().rev().any(|entry| { + entry.database_key_index == cycle_head.database_key_index + && entry.iteration_count() == cycle_head.iteration_count + }) + }) { + return false; + } + } + + true + } + /// VerifyResult::Unchanged if the memo's value and `changed_at` time is up-to-date in the /// current revision. When this returns Unchanged with no cycle heads, it also updates the /// memo's `verified_at` field if needed to make future calls cheaper. @@ -390,7 +420,7 @@ where let in_heads = cycle_heads .iter() - .position(|&head| head == database_key_index) + .position(|&head| head.database_key_index == database_key_index) .inspect(|&head| _ = cycle_heads.swap_remove(head)) .is_some(); diff --git a/src/function/memo.rs b/src/function/memo.rs index 00b1440b3..5e28d9265 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -173,15 +173,16 @@ impl<V> Memo<V> { let hit_cycle = self .cycle_heads() .into_iter() - .filter(|&head| head.database_key_index != database_key_index) + .any(|head| { - let ingredient = zalsa.lookup_ingredient(head.ingredient_index()); - if !ingredient.is_provisional_cycle_head(db, head.key_index()) { + let head_index = head.database_key_index; + let ingredient = zalsa.lookup_ingredient(head_index.ingredient_index()); + if !ingredient.is_provisional_cycle_head(db, head_index.key_index()) { // This cycle is already finalized, so we don't need to wait on it; // keep looping through cycle heads. 
retry = true; false - } else if ingredient.wait_for(db, head.key_index()) { + } else if ingredient.wait_for(db, head_index.key_index()) { // There's a new memo available for the cycle head; fetch our own // updated memo and see if it's still provisional or if the cycle // has resolved. diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 596ac7787..c93d1efa1 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -87,9 +87,13 @@ impl ZalsaLocal { } #[inline] - pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> { + pub(crate) fn push_query( + &self, + database_key_index: DatabaseKeyIndex, + iteration_count: u32, + ) -> ActiveQueryGuard<'_> { let mut query_stack = self.query_stack.borrow_mut(); - query_stack.push_new_query(database_key_index); + query_stack.push_new_query(database_key_index, iteration_count); ActiveQueryGuard { local_state: self, database_key_index, @@ -338,7 +342,7 @@ impl QueryRevisions { accumulated: Default::default(), accumulated_inputs: Default::default(), verified_final: AtomicBool::new(false), - cycle_heads: CycleHeads::from(query), + cycle_heads: CycleHeads::initial(query), } } diff --git a/tests/cycle.rs b/tests/cycle.rs index 23b81a633..596c51256 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -984,7 +984,7 @@ fn cycle_unchanged_nested_intertwined() { e.assert_value(&db, 60); } - db.assert_logs_len(16 + i); + db.assert_logs_len(12 + i); // next revision, we change only A, which is not part of the cycle and the cycle does not // depend on. @@ -1004,3 +1004,34 @@ fn cycle_unchanged_nested_intertwined() { a.assert_value(&db, 45); } } + +/// Provisional query results in a cycle should still be cached within a single iteration. 
+/// +/// a:Ni(v59, b) -> b:Np(v60, c, c, c) -> c:Np(a) +/// ^ | +/// +------------------------------------------+ +#[test] +fn repeat_provisional_query() { + let mut db = ExecuteValidateLoggerDatabase::default(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![value(59), b.clone()]); + b_in.set_inputs(&mut db) + .to(vec![value(60), c.clone(), c.clone(), c]); + c_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 59); + + db.assert_logs(expect![[r#" + [ + "salsa_event(WillExecute { database_key: min_iterate(Id(0)) })", + "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", + "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", + "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", + "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", + ]"#]]); +} diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index 1d867f900..c2862ada1 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -180,7 +180,6 @@ fn accumulate_with_cycle_second_revision() { [ "check_file(name = file_b, issues = [2, 3])", "check_file(name = file_a, issues = [1])", - "check_file(name = file_b, issues = [2, 3])", "check_file(name = file_a, issues = [1])", "check_file(name = file_b, issues = [2, 3])", ]"#]]); diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index e98106284..1908ffc03 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -169,7 +169,6 @@ fn revalidate_with_change_after_output_read() { "salsa_event(WillExecute { database_key: read_value(Id(400)) })", "salsa_event(WillExecute { database_key: query_b(Id(0)) })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", - "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: 
read_value(Id(401)) })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: read_value(Id(400)) })",