diff --git a/benches/dataflow.rs b/benches/dataflow.rs index cf20140f6..4d18a2532 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -76,13 +76,12 @@ fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type { fn def_cycle_recover( _db: &dyn Db, - _id: salsa::Id, + cycle: &salsa::Cycle, _last_provisional_value: &Type, value: Type, - count: u32, _def: Definition, ) -> Type { - cycle_recover(value, count) + cycle_recover(value, cycle.iteration()) } fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { @@ -91,13 +90,12 @@ fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { fn use_cycle_recover( _db: &dyn Db, - _id: salsa::Id, + cycle: &salsa::Cycle, _last_provisional_value: &Type, value: Type, - count: u32, _use: Use, ) -> Type { - cycle_recover(value, count) + cycle_recover(value, cycle.iteration()) } fn cycle_recover(value: Type, count: u32) -> Type { diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 1c3312372..9cb311fc5 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -308,13 +308,12 @@ macro_rules! setup_tracked_fn { fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, - id: salsa::Id, + cycle: &salsa::Cycle, last_provisional_value: &Self::Output<$db_lt>, value: Self::Output<$db_lt>, - iteration_count: u32, ($($input_id),*): ($($interned_input_ty),*) ) -> Self::Output<$db_lt> { - $($cycle_recovery_fn)*(db, id, last_provisional_value, value, iteration_count, $($input_id),*) + $($cycle_recovery_fn)*(db, cycle, last_provisional_value, value, $($input_id),*) } fn id_to_input<$db_lt>(zalsa: &$db_lt $zalsa::Zalsa, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index fe002fa4e..e22875311 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -3,8 +3,8 @@ // a macro because it can take a variadic number of arguments. #[macro_export] macro_rules! unexpected_cycle_recovery { - ($db:ident, $id:ident, $last_provisional_value:ident, $new_value:ident, $count:ident, $($other_inputs:ident),*) => {{ - let (_db, _id, _last_provisional_value, _count) = ($db, $id, $last_provisional_value, $count); + ($db:ident, $cycle:ident, $last_provisional_value:ident, $new_value:ident, $($other_inputs:ident),*) => {{ + let (_db, _cycle, _last_provisional_value) = ($db, $cycle, $last_provisional_value); std::mem::drop(($($other_inputs,)*)); $new_value }}; diff --git a/src/cycle.rs b/src/cycle.rs index 0f12472b4..8ab5dabcd 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -50,7 +50,7 @@ use thin_vec::{thin_vec, ThinVec}; use crate::key::DatabaseKeyIndex; use crate::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use crate::sync::OnceLock; -use crate::Revision; +use crate::{Id, Revision}; /// The maximum number of times we'll fixpoint-iterate before panicking. /// @@ -238,6 +238,10 @@ impl CycleHeads { } } + pub(crate) fn ids(&self) -> CycleHeadIdsIterator<'_> { + CycleHeadIdsIterator { inner: self.iter() } + } + /// Iterates over all cycle heads that aren't equal to `own`. pub(crate) fn iter_not_eq( &self, @@ -392,6 +396,7 @@ impl IntoIterator for CycleHeads { } } +#[derive(Clone)] pub struct CycleHeadsIterator<'a> { inner: std::slice::Iter<'a, CycleHead>, } @@ -448,6 +453,47 @@ pub(crate) fn empty_cycle_heads() -> &'static CycleHeads { EMPTY_CYCLE_HEADS.get_or_init(|| CycleHeads(ThinVec::new())) } +#[derive(Clone)] +pub struct CycleHeadIdsIterator<'a> { + inner: CycleHeadsIterator<'a>, +} + +impl Iterator for CycleHeadIdsIterator<'_> { + type Item = crate::Id; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|head| head.database_key_index.key_index()) + } +} + +/// The context that the cycle recovery function receives when a query cycle occurs. +pub struct Cycle<'a> { + pub(crate) head_ids: CycleHeadIdsIterator<'a>, + pub(crate) id: Id, + pub(crate) iteration: u32, +} + +impl Cycle<'_> { + /// An iterator that outputs the [`Id`]s of the current cycle heads. + /// This always contains the [`Id`] of the current query but it can contain additional cycle head [`Id`]s + /// if this query is nested in an outer cycle or if it has nested cycles. + pub fn head_ids(&self) -> CycleHeadIdsIterator<'_> { + self.head_ids.clone() + } + + /// The [`Id`] of the query that the current cycle recovery function is processing. + pub fn id(&self) -> Id { + self.id + } + + /// The counter of the current fixed point iteration. + pub fn iteration(&self) -> u32 { + self.iteration + } +} + #[derive(Debug)] pub enum ProvisionalStatus<'db> { Provisional { diff --git a/src/function.rs b/src/function.rs index b9878bc41..f7f302727 100644 --- a/src/function.rs +++ b/src/function.rs @@ -21,7 +21,7 @@ use crate::table::Table; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, JarKind, MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{QueryEdge, QueryOriginRef}; -use crate::{Id, Revision}; +use crate::{Cycle, Id, Revision}; #[cfg(feature = "accumulator")] mod accumulated; @@ -124,10 +124,9 @@ pub trait Configuration: Any { /// iterating until the returned value equals the previous iteration's value. fn recover_from_cycle<'db>( db: &'db Self::DbView, - id: Id, + cycle: &Cycle, last_provisional_value: &Self::Output<'db>, value: Self::Output<'db>, - iteration: u32, input: Self::Input<'db>, ) -> Self::Output<'db>; diff --git a/src/function/execute.rs b/src/function/execute.rs index 558ace738..b0b8b8609 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -12,7 +12,7 @@ use crate::sync::thread; use crate::tracked_struct::Identity; use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions}; -use crate::{tracing, Cancelled}; +use crate::{tracing, Cancelled, Cycle}; use crate::{DatabaseKeyIndex, Event, EventKind, Id}; impl IngredientImpl @@ -370,14 +370,18 @@ where iteration_count }; + let cycle = Cycle { + head_ids: cycle_heads.ids(), + id, + iteration: iteration_count.as_u32(), + }; // We are in a cycle that hasn't converged; ask the user's // cycle-recovery function what to do (it may return the same value or a different one): new_value = C::recover_from_cycle( db, - id, + &cycle, last_provisional_value, new_value, - iteration_count.as_u32(), C::id_to_input(zalsa, id), ); diff --git a/src/function/memo.rs b/src/function/memo.rs index f22af65fe..234829cb1 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -562,10 +562,9 @@ mod _memory_usage { fn recover_from_cycle<'db>( _: &'db Self::DbView, - _: Id, + _: &crate::Cycle, _: &Self::Output<'db>, value: Self::Output<'db>, - _: u32, _: Self::Input<'db>, ) -> Self::Output<'db> { value diff --git a/src/lib.rs b/src/lib.rs index d4409c4a9..f90fce338 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,7 @@ pub use self::accumulator::Accumulator; pub use self::active_query::Backtrace; pub use self::cancelled::Cancelled; +pub use self::cycle::Cycle; pub use self::database::Database; pub use self::database_impl::DatabaseImpl; pub use self::durability::Durability; diff --git a/tests/cycle.rs b/tests/cycle.rs index dd476ab76..ba407226e 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -125,10 +125,9 @@ const MAX_ITERATIONS: u32 = 3; /// returning the computed value to continue iterating. fn cycle_recover( _db: &dyn Db, - _id: salsa::Id, + cycle: &salsa::Cycle, last_provisional_value: &Value, value: Value, - count: u32, _inputs: Inputs, ) -> Value { if &value == last_provisional_value { @@ -138,7 +137,7 @@ fn cycle_recover( .is_some_and(|val| val <= MIN_VALUE || val >= MAX_VALUE) { Value::OutOfBounds - } else if count > MAX_ITERATIONS { + } else if cycle.iteration() > MAX_ITERATIONS { Value::TooManyIterations } else { value diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index 6377805b8..63325ec13 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -50,10 +50,9 @@ fn cycle_initial(_db: &dyn LogDatabase, _id: salsa::Id, _file: File) -> Vec fn cycle_fn( _db: &dyn LogDatabase, - _id: salsa::Id, + _cycle: &salsa::Cycle, _last_provisional_value: &[u32], value: Vec, - _count: u32, _file: File, ) -> Vec { value diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs index 77f7378e4..c0bbf00c1 100644 --- a/tests/cycle_recovery_call_back_into_cycle.rs +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -27,10 +27,9 @@ fn cycle_initial(_db: &dyn ValueDatabase, _id: salsa::Id) -> u32 { fn cycle_fn( db: &dyn ValueDatabase, - _id: salsa::Id, + _cycle: &salsa::Cycle, last_provisional_value: &u32, value: u32, - _count: u32, ) -> u32 { if &value == last_provisional_value { value diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs index dae4203d7..57c1915ab 100644 --- a/tests/cycle_recovery_call_query.rs +++ b/tests/cycle_recovery_call_query.rs @@ -23,10 +23,9 @@ fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id) -> u32 { fn cycle_fn( db: &dyn salsa::Database, - _id: salsa::Id, + _cycle: &salsa::Cycle, _last_provisional_value: &u32, _value: u32, - _count: u32, ) -> u32 { fallback_value(db) } diff --git a/tests/cycle_recovery_dependencies.rs b/tests/cycle_recovery_dependencies.rs index fe93428e5..fd9a5f956 100644 --- a/tests/cycle_recovery_dependencies.rs +++ b/tests/cycle_recovery_dependencies.rs @@ -37,10 +37,9 @@ fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> u3 fn cycle_fn( db: &dyn salsa::Database, - _id: salsa::Id, + _cycle: &salsa::Cycle, _last_provisional_value: &u32, value: u32, - _count: u32, input: Input, ) -> u32 { let _input = input.value(db); diff --git a/tests/dataflow.rs b/tests/dataflow.rs index f91123ef0..a0d50834f 100644 --- a/tests/dataflow.rs +++ b/tests/dataflow.rs @@ -77,16 +77,15 @@ fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type { fn def_cycle_recover( _db: &dyn Db, - _id: salsa::Id, + cycle: &salsa::Cycle, last_provisional_value: &Type, value: Type, - count: u32, _def: Definition, ) -> Type { if &value == last_provisional_value { value } else { - cycle_recover(value, count) + cycle_recover(value, cycle.iteration()) } } @@ -96,16 +95,15 @@ fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { fn use_cycle_recover( _db: &dyn Db, - _id: salsa::Id, + cycle: &salsa::Cycle, last_provisional_value: &Type, value: Type, - count: u32, _use: Use, ) -> Type { if &value == last_provisional_value { value } else { - cycle_recover(value, count) + cycle_recover(value, cycle.iteration()) } } diff --git a/tests/parallel/cycle_panic.rs b/tests/parallel/cycle_panic.rs index 34cbb7ed2..4afc375d5 100644 --- a/tests/parallel/cycle_panic.rs +++ b/tests/parallel/cycle_panic.rs @@ -20,10 +20,9 @@ fn query_b(db: &dyn KnobsDatabase) -> u32 { fn cycle_fn( _db: &dyn KnobsDatabase, - _id: salsa::Id, + _cycle: &salsa::Cycle, _last_provisional_value: &u32, _value: u32, - _count: u32, ) -> u32 { panic!("cancel!") }