diff --git a/examples/calc/db.rs b/examples/calc/db.rs index 63cc4fe12..05e06c0d0 100644 --- a/examples/calc/db.rs +++ b/examples/calc/db.rs @@ -48,6 +48,7 @@ impl CalcDatabaseImpl { } #[cfg(test)] + #[allow(unused)] pub fn take_logs(&self) -> Vec { let mut logs = self.logs.lock().unwrap(); if let Some(logs) = &mut *logs { diff --git a/src/active_query.rs b/src/active_query.rs index cc5e4fc58..00d0f5338 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -5,8 +5,6 @@ use crate::accumulator::{ accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues}, Accumulator, }; -use crate::cycle::{CycleHeads, IterationCount}; -use crate::durability::Durability; use crate::hash::FxIndexSet; use crate::key::DatabaseKeyIndex; use crate::runtime::Stamp; @@ -14,6 +12,11 @@ use crate::sync::atomic::AtomicBool; use crate::tracked_struct::{Disambiguator, DisambiguatorMap, IdentityHash, IdentityMap}; use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryRevisions, QueryRevisionsExtra}; use crate::Revision; +use crate::{ + cycle::{CycleHeads, IterationCount}, + Id, +}; +use crate::{durability::Durability, tracked_struct::Identity}; #[derive(Debug)] pub(crate) struct ActiveQuery { @@ -74,6 +77,7 @@ impl ActiveQuery { changed_at: Revision, edges: &[QueryEdge], untracked_read: bool, + active_tracked_ids: &[(Identity, Id)], ) { assert!(self.input_outputs.is_empty()); @@ -83,7 +87,8 @@ impl ActiveQuery { self.untracked_read |= untracked_read; // Mark all tracked structs from the previous iteration as active. - self.tracked_struct_ids.mark_all_active(); + self.tracked_struct_ids + .mark_all_active(active_tracked_ids.iter().copied()); } pub(super) fn add_read( @@ -408,7 +413,7 @@ pub(crate) struct CompletedQuery { /// The keys of any tracked structs that were created in a previous execution of the /// query but not the current one, and should be marked as stale. - pub(crate) stale_tracked_structs: Vec, + pub(crate) stale_tracked_structs: Vec<(Identity, Id)>, } struct CapturedQuery { diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index 923a0fc88..003310ae1 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -27,8 +27,9 @@ where // Note that tracked structs are not stored as direct query outputs, but they are still outputs // that need to be reported as stale. - for output in &completed_query.stale_tracked_structs { - Self::report_stale_output(zalsa, key, *output); + for (identity, id) in &completed_query.stale_tracked_structs { + let output = DatabaseKeyIndex::new(identity.ingredient_index(), *id); + Self::report_stale_output(zalsa, key, output); } let mut stale_outputs = output_edges(edges).collect::>(); diff --git a/src/function/execute.rs b/src/function/execute.rs index 67cee969d..d1651859e 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -3,6 +3,7 @@ use crate::cycle::{CycleRecoveryStrategy, IterationCount}; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; use crate::sync::atomic::{AtomicBool, Ordering}; +use crate::tracked_struct::Identity; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::ActiveQueryGuard; use crate::{Event, EventKind, Id}; @@ -134,13 +135,25 @@ where let database_key_index = active_query.database_key_index; let mut iteration_count = IterationCount::initial(); let mut fell_back = false; + let zalsa_local = db.zalsa_local(); // Our provisional value from the previous iteration, when doing fixpoint iteration. // Initially it's set to None, because the initial provisional value is created lazily, // only when a cycle is actually encountered. let mut opt_last_provisional: Option<&Memo<'db, C>> = None; + let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new(); + loop { let previous_memo = opt_last_provisional.or(opt_old_memo); + + // Tracked struct ids that existed in the previous revision + // but weren't recreated in the last iteration. It's important that we seed the next + // query with these ids because the query might re-create them as part of the next iteration. + // This is not only important to ensure that the re-created tracked structs have the same ids, + // it's also important to ensure that these tracked structs get removed + // if they aren't recreated when reaching the final iteration. + active_query.seed_tracked_struct_ids(&last_stale_tracked_ids); + let (mut new_value, mut completed_query) = Self::execute_query(db, zalsa, active_query, previous_memo, id); @@ -239,10 +252,9 @@ where ), memo_ingredient_index, )); + last_stale_tracked_ids = completed_query.stale_tracked_structs; - active_query = db - .zalsa_local() - .push_query(database_key_index, iteration_count); + active_query = zalsa_local.push_query(database_key_index, iteration_count); continue; } @@ -280,9 +292,7 @@ where if let Some(old_memo) = opt_old_memo { // If we already executed this query once, then use the tracked-struct ids from the // previous execution as the starting point for the new one. - if let Some(tracked_struct_ids) = old_memo.revisions.tracked_struct_ids() { - active_query.seed_tracked_struct_ids(tracked_struct_ids); - } + active_query.seed_tracked_struct_ids(old_memo.revisions.tracked_struct_ids()); // Copy over all inputs and outputs from a previous iteration. // This is necessary to: diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 54fce885d..e7c6597f6 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -161,11 +161,12 @@ where old_memo = old_memo.tracing_debug() ); + let zalsa_local = db.zalsa_local(); let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, old_memo); if can_shallow_update.yes() && self.validate_may_be_provisional( zalsa, - db.zalsa_local(), + zalsa_local, database_key_index, old_memo, // Don't conclude that the query is unchanged if the memo itself is still @@ -506,7 +507,7 @@ where old_memo = old_memo.tracing_debug() ); - debug_assert!(!cycle_heads.contains(database_key_index)); + assert!(!cycle_heads.contains(database_key_index)); match old_memo.revisions.origin.as_ref() { QueryOriginRef::Derived(edges) => { diff --git a/src/function/memo.rs b/src/function/memo.rs index 4894cc642..9671c83d1 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -326,7 +326,7 @@ where stale_output.remove_stale_output(zalsa, executor); } - for (identity, id) in self.revisions.tracked_struct_ids().into_iter().flatten() { + for (identity, id) in self.revisions.tracked_struct_ids() { let key = DatabaseKeyIndex::new(identity.ingredient_index(), *id); key.remove_stale_output(zalsa, executor); } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 7ef998a4b..cccb13fd1 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -255,19 +255,15 @@ pub(crate) struct IdentityMap { impl IdentityMap { /// Seeds the identity map with the IDs from a previous revision. pub(crate) fn seed(&mut self, source: &[(Identity, Id)]) { - self.table.clear(); - self.table - .reserve(source.len(), |entry| entry.identity.hash); - for &(key, id) in source { self.insert_entry(key, id, false); } } // Mark all tracked structs in the map as created by the current query. - pub(crate) fn mark_all_active(&mut self) { - for entry in self.table.iter_mut() { - entry.active = true; + pub(crate) fn mark_all_active(&mut self, items: impl IntoIterator) { + for (key, id) in items { + self.insert_entry(key, id, true); } } @@ -330,7 +326,8 @@ impl IdentityMap { /// The first entry contains the identity and IDs of any tracked structs that were /// created by the current execution of the query, while the second entry contains any /// tracked structs that were created in a previous execution but not the current one. - pub(crate) fn drain(&mut self) -> (ThinVec<(Identity, Id)>, Vec) { + #[expect(clippy::type_complexity)] + pub(crate) fn drain(&mut self) -> (ThinVec<(Identity, Id)>, Vec<(Identity, Id)>) { if self.table.is_empty() { return (ThinVec::new(), Vec::new()); } @@ -342,19 +339,14 @@ impl IdentityMap { if entry.active { active.push((entry.identity, entry.id)); } else { - stale.push(DatabaseKeyIndex::new( - entry.identity.ingredient_index(), - entry.id, - )); + stale.push((entry.identity, entry.id)); } } // Removing a stale tracked struct ID shows up in the event logs, so make sure // the order is stable here. stale.sort_unstable_by(|a, b| { - a.ingredient_index() - .cmp(&b.ingredient_index()) - .then(a.key_index().cmp(&b.key_index())) + (a.0.ingredient_index(), a.1).cmp(&(b.0.ingredient_index(), b.1)) }); (active, stale) diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index ced5e9281..77387f72f 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -668,13 +668,13 @@ impl QueryRevisions { } } - /// Returns a reference to the `IdentityMap` for this query, or `None` if the map is empty. - pub fn tracked_struct_ids(&self) -> Option<&[(Identity, Id)]> { + /// Returns the ids of the tracked structs created when running this query. + pub fn tracked_struct_ids(&self) -> &[(Identity, Id)] { self.extra .0 .as_ref() .map(|extra| &*extra.tracked_struct_ids) - .filter(|tracked_struct_ids| !tracked_struct_ids.is_empty()) + .unwrap_or_default() } /// Returns a mutable reference to the `IdentityMap` for this query, or `None` if the map is empty. @@ -1090,7 +1090,6 @@ impl ActiveQueryGuard<'_> { #[cfg(debug_assertions)] assert_eq!(stack.len(), self.push_len); let frame = stack.last_mut().unwrap(); - assert!(frame.tracked_struct_ids().is_empty()); frame.tracked_struct_ids_mut().seed(tracked_struct_ids); }) } @@ -1105,6 +1104,7 @@ impl ActiveQueryGuard<'_> { previous.origin.as_ref(), QueryOriginRef::DerivedUntracked(_) ); + let tracked_ids = previous.tracked_struct_ids(); // SAFETY: We do not access the query stack reentrantly. unsafe { @@ -1112,7 +1112,7 @@ impl ActiveQueryGuard<'_> { #[cfg(debug_assertions)] assert_eq!(stack.len(), self.push_len); let frame = stack.last_mut().unwrap(); - frame.seed_iteration(durability, changed_at, edges, untracked_read); + frame.seed_iteration(durability, changed_at, edges, untracked_read, tracked_ids); }) } } diff --git a/tests/cycle.rs b/tests/cycle.rs index 2eb9bac23..f2a3334c3 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -1129,11 +1129,12 @@ fn repeat_provisional_query_incremental() { // `validate_same_iteration` incorrectly returns `false`. db.assert_logs(expect![[r#" [ - "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", - "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", "salsa_event(WillExecute { database_key: min_iterate(Id(0)) })", + "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", + "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", "salsa_event(WillIterateCycle { database_key: min_iterate(Id(0)), iteration_count: IterationCount(1), fell_back: false })", "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", ]"#]]); } + diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs index b9ef6ed14..c8a4cd451 100644 --- a/tests/cycle_tracked.rs +++ b/tests/cycle_tracked.rs @@ -1,8 +1,5 @@ #![cfg(feature = "inventory")] -//! Tests for cycles where the cycle head is stored on a tracked struct -//! and that tracked struct is freed in a later revision. - mod common; use crate::common::{EventLoggerDatabase, LogDatabase}; @@ -45,6 +42,7 @@ struct Node<'db> { #[salsa::input(debug)] struct GraphInput { simple: bool, + fixpoint_variant: usize, } #[salsa::tracked(returns(ref))] @@ -125,11 +123,13 @@ fn cycle_recover( CycleRecoveryAction::Iterate } +/// Tests for cycles where the cycle head is stored on a tracked struct +/// and that tracked struct is freed in a later revision. #[test] fn main() { let mut db = EventLoggerDatabase::default(); - let input = GraphInput::new(&db, false); + let input = GraphInput::new(&db, false, 0); let graph = create_graph(&db, input); let c = graph.find_node(&db, "c").unwrap(); @@ -192,3 +192,250 @@ fn main() { "WillCheckCancellation", ]"#]]); } + +#[salsa::tracked] +struct IterationNode<'db> { + #[returns(ref)] + name: String, + iteration: usize, +} + +/// A cyclic query that creates more tracked structs in later fixpoint iterations. +/// +/// The output depends on the input's fixpoint_variant: +/// - variant=0: Returns `[base]` (1 struct, no cycle) +/// - variant=1: Through fixpoint iteration, returns `[iter_0, iter_1, iter_2]` (3 structs) +/// - variant=2: Through fixpoint iteration, returns `[iter_0, iter_1]` (2 structs) +/// - variant>2: Through fixpoint iteration, returns `[iter_0, iter_1]` (2 structs, same as variant=2) +/// +/// When variant > 0, the query creates a cycle by calling itself. The fixpoint iteration +/// proceeds as follows: +/// 1. Initial: returns empty vector +/// 2. First iteration: returns `[iter_0]` +/// 3. Second iteration: returns `[iter_0, iter_1]` +/// 4. Third iteration (only for variant=1): returns `[iter_0, iter_1, iter_2]` +/// 5. Further iterations: no change, fixpoint reached +#[salsa::tracked(cycle_fn=cycle_recover_with_structs, cycle_initial=initial_with_structs)] +fn create_tracked_in_cycle<'db>( + db: &'db dyn Database, + input: GraphInput, +) -> Vec> { + // Check if we should create more nodes based on the input. + let variant = input.fixpoint_variant(db); + + if variant == 0 { + // Base case - no cycle, just return a single node. + vec![IterationNode::new(db, "base".to_string(), 0)] + } else { + // Create a cycle by calling ourselves. + let previous = create_tracked_in_cycle(db, input); + + // In later iterations, create additional tracked structs. + if previous.is_empty() { + // First iteration - initial returns empty. + vec![IterationNode::new(db, "iter_0".to_string(), 0)] + } else { + // Limit based on variant: variant=1 allows 3 nodes, variant=2 allows 2 nodes. + let limit = if variant == 1 { 3 } else { 2 }; + + if previous.len() < limit { + // Subsequent iterations - add more nodes. + let mut nodes = previous; + nodes.push(IterationNode::new( + db, + format!("iter_{}", nodes.len()), + nodes.len(), + )); + nodes + } else { + // Reached the limit. + previous + } + } + } +} + +fn initial_with_structs(_db: &dyn Database, _input: GraphInput) -> Vec> { + vec![] +} + +fn cycle_recover_with_structs<'db>( + _db: &'db dyn Database, + _value: &Vec>, + _iteration: u32, + _input: GraphInput, +) -> CycleRecoveryAction>> { + CycleRecoveryAction::Iterate +} + +#[test] +fn test_cycle_with_fixpoint_structs() { + let mut db = EventLoggerDatabase::default(); + + // Create an input that will trigger the cyclic behavior. + let input = GraphInput::new(&db, false, 1); + + // Initial query - this will create structs across multiple iterations. + let nodes = create_tracked_in_cycle(&db, input); + assert_eq!(nodes.len(), 3); + // First iteration: previous is empty [], so we get [iter_0] + // Second iteration: previous is [iter_0], so we get [iter_0, iter_1] + // Third iteration: previous is [iter_0, iter_1], so we get [iter_0, iter_1, iter_2] + assert_eq!(nodes[0].name(&db), "iter_0"); + assert_eq!(nodes[1].name(&db), "iter_1"); + assert_eq!(nodes[2].name(&db), "iter_2"); + + // Clear logs to focus on the change. + db.clear_logs(); + + // Change the input to force re-execution with a different variant. + // This will create 2 tracked structs instead of 3 (one fewer than before). + input.set_fixpoint_variant(&mut db).to(2); + + // Re-query - this should handle the tracked struct changes properly. + let nodes = create_tracked_in_cycle(&db, input); + assert_eq!(nodes.len(), 2); + assert_eq!(nodes[0].name(&db), "iter_0"); + assert_eq!(nodes[1].name(&db), "iter_1"); + + // Check the logs to ensure proper execution and struct management. + // We should see the third struct (iter_2) being discarded. + db.assert_logs(expect![[r#" + [ + "DidSetCancellationFlag", + "WillCheckCancellation", + "WillExecute { database_key: create_tracked_in_cycle(Id(0)) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(1), fell_back: false }", + "WillCheckCancellation", + "WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(2), fell_back: false }", + "WillCheckCancellation", + "WillDiscardStaleOutput { execute_key: create_tracked_in_cycle(Id(0)), output_key: IterationNode(Id(402)) }", + "DidDiscard { key: IterationNode(Id(402)) }", + ]"#]]); +} + +// Additional test structures for the new scenario +#[salsa::tracked] +struct TrackedValue<'db> { + value: u32, +} + +#[salsa::input] +struct InputValue { + value: u32, +} + +#[salsa::input] +struct IterationCounter { + count: std::sync::Arc, +} + +#[salsa::tracked] +fn query_c<'db>(db: &'db dyn Database, tracked: TrackedValue<'db>) -> u32 { + tracked.value(db) +} + +#[salsa::tracked(cycle_fn=cycle_recover_b, cycle_initial=initial_b)] +fn query_b<'db>(db: &'db dyn Database, input: InputValue) -> u32 { + // Call query_a to create the cycle + let a_result = query_a(db, input); + + // Only create tracked struct when a_result reaches a certain threshold + // This creates an internal condition for when to create the tracked struct + if a_result <= 50 { + let tracked = TrackedValue::new(db, 42); + let c_result = query_c(db, tracked); + c_result + } else { + a_result - 10 // Reduce by 10 to force iteration + } +} + +fn initial_b(_db: &dyn Database, _input: InputValue) -> u32 { + u32::MAX +} + +fn cycle_recover_b( + _db: &dyn Database, + _value: &u32, + _count: u32, + _input: InputValue, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +#[salsa::tracked(cycle_fn=cycle_recover_a, cycle_initial=initial_a)] +fn query_a<'db>(db: &'db dyn Database, input: InputValue) -> u32 { + let input_val = input.value(db); + // Call query_b to create the cycle + let b_result = query_b(db, input); + b_result.min(input_val) +} + +fn initial_a(_db: &dyn Database, _input: InputValue) -> u32 { + u32::MAX +} + +fn cycle_recover_a( + _db: &dyn Database, + _value: &u32, + _count: u32, + _input: InputValue, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +/// Test scenario with tracked struct created during cycle iteration. +/// +/// a -> b -> a (cycle) +/// -> c(tracked_struct) +/// +/// - a is the cycle head +/// - b participates in the cycle and creates a tracked struct based on internal condition +/// - The tracked struct is created when a_result <= 50 (internal condition, not explicit counter) +/// - When input changes, a must rerun and should panic due to tracked struct cleanup issue +#[test] +#[should_panic(expected = "cannot delete read-locked id")] +fn cycle_with_tracked_struct_creation_during_iteration() { + let mut db = EventLoggerDatabase::default(); + + // Set up inputs + let input = InputValue::new(&db, 50); + + // Execute query_a which triggers the cycle + let result = query_a(&db, input); + + // First iteration: a returns input (50), b returns a_result - 10 = 40 + // Second iteration: a returns min(50, 40) = 40, b sees a_result <= 50, creates tracked struct + // Result should be the tracked struct value (42) + assert_eq!(result, 42); + + // Clear logs for the next part + db.clear_logs(); + + // Change the input to force recomputation + input.set_value(&mut db).to(30); + + // Re-execute, should see appropriate logs + let result2 = query_a(&db, input); + + // New result should be 42 (tracked struct value), but this should panic first + assert_eq!(result2, 42); + + // Verify we see the expected execution pattern in logs + db.assert_logs(expect![[r#" + [ + "DidSetCancellationFlag", + "WillCheckCancellation", + "WillExecute { database_key: query_a(Id(0)) }", + "WillCheckCancellation", + "WillCheckCancellation", + "WillExecute { database_key: query_b(Id(0)) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query_a(Id(0)), iteration_count: IterationCount(1), fell_back: false }", + "WillCheckCancellation", + "WillExecute { database_key: query_b(Id(0)) }", + "WillCheckCancellation", + ]"#]]); +}