Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions benches/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@ fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type {

fn def_cycle_recover(
_db: &dyn Db,
_id: salsa::Id,
cycle: &salsa::Cycle,
_last_provisional_value: &Type,
value: Type,
count: u32,
_def: Definition,
) -> Type {
cycle_recover(value, count)
cycle_recover(value, cycle.iteration())
}

fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type {
Expand All @@ -91,13 +90,12 @@ fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type {

fn use_cycle_recover(
_db: &dyn Db,
_id: salsa::Id,
cycle: &salsa::Cycle,
_last_provisional_value: &Type,
value: Type,
count: u32,
_use: Use,
) -> Type {
cycle_recover(value, count)
cycle_recover(value, cycle.iteration())
}

fn cycle_recover(value: Type, count: u32) -> Type {
Expand Down
5 changes: 2 additions & 3 deletions components/salsa-macro-rules/src/setup_tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,12 @@ macro_rules! setup_tracked_fn {

fn recover_from_cycle<$db_lt>(
db: &$db_lt dyn $Db,
id: salsa::Id,
cycle: &salsa::Cycle,
last_provisional_value: &Self::Output<$db_lt>,
value: Self::Output<$db_lt>,
iteration_count: u32,
($($input_id),*): ($($interned_input_ty),*)
) -> Self::Output<$db_lt> {
$($cycle_recovery_fn)*(db, id, last_provisional_value, value, iteration_count, $($input_id),*)
$($cycle_recovery_fn)*(db, cycle, last_provisional_value, value, $($input_id),*)
}

fn id_to_input<$db_lt>(zalsa: &$db_lt $zalsa::Zalsa, key: salsa::Id) -> Self::Input<$db_lt> {
Expand Down
4 changes: 2 additions & 2 deletions components/salsa-macro-rules/src/unexpected_cycle_recovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// a macro because it can take a variadic number of arguments.
#[macro_export]
macro_rules! unexpected_cycle_recovery {
($db:ident, $id:ident, $last_provisional_value:ident, $new_value:ident, $count:ident, $($other_inputs:ident),*) => {{
let (_db, _id, _last_provisional_value, _count) = ($db, $id, $last_provisional_value, $count);
($db:ident, $cycle:ident, $last_provisional_value:ident, $new_value:ident, $($other_inputs:ident),*) => {{
let (_db, _cycle, _last_provisional_value) = ($db, $cycle, $last_provisional_value);
std::mem::drop(($($other_inputs,)*));
$new_value
}};
Expand Down
48 changes: 47 additions & 1 deletion src/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use thin_vec::{thin_vec, ThinVec};
use crate::key::DatabaseKeyIndex;
use crate::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use crate::sync::OnceLock;
use crate::Revision;
use crate::{Id, Revision};

/// The maximum number of times we'll fixpoint-iterate before panicking.
///
Expand Down Expand Up @@ -238,6 +238,10 @@ impl CycleHeads {
}
}

pub(crate) fn ids(&self) -> CycleHeadIdsIterator<'_> {
CycleHeadIdsIterator { inner: self.iter() }
}

/// Iterates over all cycle heads that aren't equal to `own`.
pub(crate) fn iter_not_eq(
&self,
Expand Down Expand Up @@ -392,6 +396,7 @@ impl IntoIterator for CycleHeads {
}
}

#[derive(Clone)]
pub struct CycleHeadsIterator<'a> {
inner: std::slice::Iter<'a, CycleHead>,
}
Expand Down Expand Up @@ -448,6 +453,47 @@ pub(crate) fn empty_cycle_heads() -> &'static CycleHeads {
EMPTY_CYCLE_HEADS.get_or_init(|| CycleHeads(ThinVec::new()))
}

#[derive(Clone)]
pub struct CycleHeadIdsIterator<'a> {
inner: CycleHeadsIterator<'a>,
}

impl Iterator for CycleHeadIdsIterator<'_> {
type Item = crate::Id;

fn next(&mut self) -> Option<Self::Item> {
self.inner
.next()
.map(|head| head.database_key_index.key_index())
}
}

/// The context that the cycle recovery function receives when a query cycle occurs.
pub struct Cycle<'a> {
pub(crate) head_ids: CycleHeadIdsIterator<'a>,
pub(crate) id: Id,
pub(crate) iteration: u32,
}

impl Cycle<'_> {
/// An iterator that outputs the [`Id`]s of the current cycle heads.
/// This always contains the [`Id`] of the current query but it can contain additional cycle head [`Id`]s
/// if this query is nested in an outer cycle or if it has nested cycles.
pub fn head_ids(&self) -> CycleHeadIdsIterator<'_> {
self.head_ids.clone()
}

/// The [`Id`] of the query that the current cycle recovery function is processing.
pub fn id(&self) -> Id {
self.id
}

/// The counter of the current fixed point iteration.
pub fn iteration(&self) -> u32 {
self.iteration
}
}

#[derive(Debug)]
pub enum ProvisionalStatus<'db> {
Provisional {
Expand Down
5 changes: 2 additions & 3 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::table::Table;
use crate::views::DatabaseDownCaster;
use crate::zalsa::{IngredientIndex, JarKind, MemoIngredientIndex, Zalsa};
use crate::zalsa_local::{QueryEdge, QueryOriginRef};
use crate::{Id, Revision};
use crate::{Cycle, Id, Revision};

#[cfg(feature = "accumulator")]
mod accumulated;
Expand Down Expand Up @@ -124,10 +124,9 @@ pub trait Configuration: Any {
/// iterating until the returned value equals the previous iteration's value.
fn recover_from_cycle<'db>(
db: &'db Self::DbView,
id: Id,
cycle: &Cycle,
last_provisional_value: &Self::Output<'db>,
value: Self::Output<'db>,
iteration: u32,
input: Self::Input<'db>,
) -> Self::Output<'db>;

Expand Down
10 changes: 7 additions & 3 deletions src/function/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::sync::thread;
use crate::tracked_struct::Identity;
use crate::zalsa::{MemoIngredientIndex, Zalsa};
use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions};
use crate::{tracing, Cancelled};
use crate::{tracing, Cancelled, Cycle};
use crate::{DatabaseKeyIndex, Event, EventKind, Id};

impl<C> IngredientImpl<C>
Expand Down Expand Up @@ -370,14 +370,18 @@ where
iteration_count
};

let cycle = Cycle {
head_ids: cycle_heads.ids(),
id,
iteration: iteration_count.as_u32(),
};
// We are in a cycle that hasn't converged; ask the user's
// cycle-recovery function what to do (it may return the same value or a different one):
new_value = C::recover_from_cycle(
db,
id,
&cycle,
last_provisional_value,
new_value,
iteration_count.as_u32(),
C::id_to_input(zalsa, id),
);

Expand Down
3 changes: 1 addition & 2 deletions src/function/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,10 +562,9 @@ mod _memory_usage {

fn recover_from_cycle<'db>(
_: &'db Self::DbView,
_: Id,
_: &crate::Cycle,
_: &Self::Output<'db>,
value: Self::Output<'db>,
_: u32,
_: Self::Input<'db>,
) -> Self::Output<'db> {
value
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub use self::accumulator::Accumulator;
pub use self::active_query::Backtrace;
pub use self::cancelled::Cancelled;

pub use self::cycle::Cycle;
pub use self::database::Database;
pub use self::database_impl::DatabaseImpl;
pub use self::durability::Durability;
Expand Down
5 changes: 2 additions & 3 deletions tests/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,9 @@ const MAX_ITERATIONS: u32 = 3;
/// returning the computed value to continue iterating.
fn cycle_recover(
_db: &dyn Db,
_id: salsa::Id,
cycle: &salsa::Cycle,
last_provisional_value: &Value,
value: Value,
count: u32,
_inputs: Inputs,
) -> Value {
if &value == last_provisional_value {
Expand All @@ -138,7 +137,7 @@ fn cycle_recover(
.is_some_and(|val| val <= MIN_VALUE || val >= MAX_VALUE)
{
Value::OutOfBounds
} else if count > MAX_ITERATIONS {
} else if cycle.iteration() > MAX_ITERATIONS {
Value::TooManyIterations
} else {
value
Expand Down
3 changes: 1 addition & 2 deletions tests/cycle_accumulate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ fn cycle_initial(_db: &dyn LogDatabase, _id: salsa::Id, _file: File) -> Vec<u32>

fn cycle_fn(
_db: &dyn LogDatabase,
_id: salsa::Id,
_cycle: &salsa::Cycle,
_last_provisional_value: &[u32],
value: Vec<u32>,
_count: u32,
_file: File,
) -> Vec<u32> {
value
Expand Down
3 changes: 1 addition & 2 deletions tests/cycle_recovery_call_back_into_cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ fn cycle_initial(_db: &dyn ValueDatabase, _id: salsa::Id) -> u32 {

fn cycle_fn(
db: &dyn ValueDatabase,
_id: salsa::Id,
_cycle: &salsa::Cycle,
last_provisional_value: &u32,
value: u32,
_count: u32,
) -> u32 {
if &value == last_provisional_value {
value
Expand Down
3 changes: 1 addition & 2 deletions tests/cycle_recovery_call_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@ fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id) -> u32 {

fn cycle_fn(
db: &dyn salsa::Database,
_id: salsa::Id,
_cycle: &salsa::Cycle,
_last_provisional_value: &u32,
_value: u32,
_count: u32,
) -> u32 {
fallback_value(db)
}
Expand Down
3 changes: 1 addition & 2 deletions tests/cycle_recovery_dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> u3

fn cycle_fn(
db: &dyn salsa::Database,
_id: salsa::Id,
_cycle: &salsa::Cycle,
_last_provisional_value: &u32,
value: u32,
_count: u32,
input: Input,
) -> u32 {
let _input = input.value(db);
Expand Down
10 changes: 4 additions & 6 deletions tests/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,15 @@ fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type {

fn def_cycle_recover(
_db: &dyn Db,
_id: salsa::Id,
cycle: &salsa::Cycle,
last_provisional_value: &Type,
value: Type,
count: u32,
_def: Definition,
) -> Type {
if &value == last_provisional_value {
value
} else {
cycle_recover(value, count)
cycle_recover(value, cycle.iteration())
}
}

Expand All @@ -96,16 +95,15 @@ fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type {

fn use_cycle_recover(
_db: &dyn Db,
_id: salsa::Id,
cycle: &salsa::Cycle,
last_provisional_value: &Type,
value: Type,
count: u32,
_use: Use,
) -> Type {
if &value == last_provisional_value {
value
} else {
cycle_recover(value, count)
cycle_recover(value, cycle.iteration())
}
}

Expand Down
3 changes: 1 addition & 2 deletions tests/parallel/cycle_panic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ fn query_b(db: &dyn KnobsDatabase) -> u32 {

fn cycle_fn(
_db: &dyn KnobsDatabase,
_id: salsa::Id,
_cycle: &salsa::Cycle,
_last_provisional_value: &u32,
_value: u32,
_count: u32,
) -> u32 {
panic!("cancel!")
}
Expand Down
Loading