diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 47921f033..30442518c 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -19,9 +19,6 @@ pub enum VerifyResult { /// /// The inner value tracks whether the memo or any of its dependencies have an /// accumulated value. - /// - /// Don't mark memos verified until we've iterated the full cycle to ensure no inputs changed - /// when encountering this variant. Unchanged(InputAccumulatedValues), } @@ -37,10 +34,6 @@ impl VerifyResult { pub(crate) fn unchanged() -> Self { Self::Unchanged(InputAccumulatedValues::Empty) } - - pub(crate) const fn is_unchanged(&self) -> bool { - matches!(self, Self::Unchanged(_)) - } } impl IngredientImpl @@ -146,11 +139,11 @@ where // Check if the inputs are still valid. We can just compare `changed_at`. let deep_verify = self.deep_verify_memo(db, zalsa, old_memo, database_key_index, cycle_heads); - if deep_verify.is_unchanged() { + if let VerifyResult::Unchanged(accumulated_inputs) = deep_verify { return Some(if old_memo.revisions.changed_at > revision { VerifyResult::Changed } else { - VerifyResult::Unchanged(old_memo.revisions.accumulated_inputs.load()) + VerifyResult::Unchanged(accumulated_inputs) }); } @@ -316,18 +309,18 @@ where memo = memo.tracing_debug() ); - if memo.revisions.cycle_heads.is_empty() { + let cycle_heads = &memo.revisions.cycle_heads; + if cycle_heads.is_empty() { return true; } - let cycle_heads = &memo.revisions.cycle_heads; - zalsa_local.with_query_stack(|stack| { cycle_heads.iter().all(|cycle_head| { - stack.iter().rev().any(|query| { - query.database_key_index == cycle_head.database_key_index - && query.iteration_count() == cycle_head.iteration_count - }) + stack + .iter() + .rev() + .find(|query| query.database_key_index == cycle_head.database_key_index) + .is_some_and(|query| query.iteration_count() == cycle_head.iteration_count) }) }) } @@ -402,16 +395,18 @@ where return VerifyResult::Changed; } + let dyn_db = db.as_dyn_database(); + + let mut last_verified_at = old_memo.verified_at.load(); + let mut first_iteration = true; 'cycle: loop { + let mut inputs = InputAccumulatedValues::Empty; // Fully tracked inputs? Iterate over the inputs and check them, one by one. // // NB: It's important here that we are iterating the inputs in the order that // they executed. It's possible that if the value of some input I0 is no longer // valid, then some later input I1 might never have executed at all, so verifying // it is still up to date is meaningless. - let last_verified_at = old_memo.verified_at.load(); - let mut inputs = InputAccumulatedValues::Empty; - let dyn_db = db.as_dyn_database(); for &edge in edges.input_outputs.iter() { match edge { QueryEdge::Input(dependency_index) => { @@ -421,9 +416,7 @@ where last_verified_at, cycle_heads, ) { - VerifyResult::Changed => { - break 'cycle VerifyResult::Changed; - } + VerifyResult::Changed => break 'cycle VerifyResult::Changed, VerifyResult::Unchanged(input_accumulated) => { inputs |= input_accumulated; } @@ -477,9 +470,17 @@ where // from cycle heads. We will handle our own memo (and the rest of our cycle) on a // future iteration; first the outer cycle head needs to verify itself. - let in_heads = cycle_heads.remove(&database_key_index); + let was_in_heads = cycle_heads.remove(&database_key_index); + let heads_non_empty = !cycle_heads.is_empty(); + if heads_non_empty { + // case 2 / 4 + break 'cycle VerifyResult::Unchanged(inputs); + } else if !first_iteration { + // 3 (second loop turn) + break 'cycle VerifyResult::Unchanged(inputs); + } else { + last_verified_at = zalsa.current_revision(); - if cycle_heads.is_empty() { old_memo.mark_as_verified(zalsa, database_key_index); old_memo.revisions.accumulated_inputs.store(inputs); @@ -490,11 +491,15 @@ where .store(true, Ordering::Relaxed); } - if in_heads { + if was_in_heads { + first_iteration = false; + // case 3 continue 'cycle; + } else { + // case 1 + break 'cycle VerifyResult::Unchanged(inputs); } } - break 'cycle VerifyResult::Unchanged(inputs); } } } diff --git a/tests/cycle.rs b/tests/cycle.rs index 25702a153..55f6b97c8 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -882,7 +882,6 @@ fn cycle_unchanged() { [ "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", - "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", ]"#]]); a.assert_value(&db, 45); @@ -929,9 +928,7 @@ fn cycle_unchanged_nested() { "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(4)) })", - "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", - "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", ]"#]]); a.assert_value(&db, 45); @@ -992,14 +989,12 @@ fn cycle_unchanged_nested_intertwined() { b.assert_value(&db, 60); db.assert_logs(expect![[r#" - [ - "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", - "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", - "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(4)) })", - "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", - "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", - "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", - ]"#]]); + [ + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(4)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", + ]"#]]); a.assert_value(&db, 45); } diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index a82aa4dfd..a61e34a0e 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -158,7 +158,6 @@ fn revalidate_no_changes() { "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(402)) })", "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(403)) })", "salsa_event(DidValidateMemoizedValue { database_key: query_a(Id(0)) })", - "salsa_event(DidValidateMemoizedValue { database_key: query_b(Id(0)) })", ]"#]]); }