Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 93 additions & 47 deletions src/function/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ where
let mut completed_query = active_query.pop();
completed_query
.revisions
.update_iteration_count_mut(database_key_index, iteration_count);
.update_cycle_participant_iteration_count(iteration_count);

claim_guard.set_release_mode(ReleaseMode::SelfOnly);
break (new_value, completed_query);
Expand All @@ -253,6 +253,10 @@ where
// Did the new result we got depend on our own provisional value, in a cycle?
// If not, return because this query is not a cycle head.
if !depends_on_self {
let Some(outer_cycle) = outer_cycle else {
panic!("cycle participant with non-empty cycle heads and that doesn't depend on itself must have an outer cycle responsible to finalize the query later (query: {database_key_index:?}, cycle heads: {cycle_heads:?}).");
};

let completed_query = complete_cycle_participant(
active_query,
claim_guard,
Expand Down Expand Up @@ -328,7 +332,6 @@ where
let value_converged = C::values_equal(&new_value, last_provisional_value);

let completed_query = match try_complete_cycle_head(
zalsa,
active_query,
claim_guard,
cycle_heads,
Expand Down Expand Up @@ -479,8 +482,8 @@ fn outer_cycle(
stack
.iter()
.find(|active_query| {
cycle_heads.contains(&active_query.database_key_index)
&& active_query.database_key_index != current_key
active_query.database_key_index != current_key
&& cycle_heads.contains(&active_query.database_key_index)
})
.map(|active_query| active_query.database_key_index)
})
Expand All @@ -503,64 +506,107 @@ fn outer_cycle(
}

/// Ensure that we resolve the latest cycle heads from any provisional value this query depended on during execution.
/// This isn't required in a single-threaded execution, but it's not guaranteed that `cycle_heads` contains all cycles
/// in a multi-threaded execution:
///
/// t1: a -> b
/// t2: c -> b (blocks on t1)
/// t1: a -> b -> c (cycle, returns fixpoint initial with c(0) in heads)
/// t1: a -> b (completes b, b has c(0) in its cycle heads, releases `b`, which resumes `t2`, and `retry_provisional` blocks on `c` (t2))
/// t2: c -> a (cycle, returns fixpoint initial for a with a(0) in heads)
/// t2: completes c, `provisional_retry` blocks on `a` (t2)
/// t1: a (completes `b` with `c` in heads)
/// ```txt
/// E -> C -> D -> B -> A -> B (cycle)
/// -- A completes, heads = [B]
/// E -> C -> D -> B -> C (cycle)
/// -> D (cycle)
/// -- B completes, heads = [B, C, D]
/// E -> C -> D -> E (cycle)
/// -- D completes, heads = [E, B, C, D]
/// E -> C
/// -- C completes, heads = [E, B, C, D]
/// E -> X -> A
/// -- X completes, heads = [B]
/// ```
///
/// Note how `a` only depends on `c` but not `a`. This is because `a` only saw the initial value of `c` and wasn't updated when `c` completed.
/// That's why we need to resolve the cycle heads recursively so `cycle_heads` contains all cycle heads at the moment this query completed.
/// Note how `X` only depends on `A`. It doesn't know that it's part of the outer cycle `E`.
/// An old implementation resolved the cycle heads only one level deep, but that's not enough:
/// `X` would then complete with `[B, C, D]` as its heads, yet `B`, `C`, and `D` are no longer
/// on the stack when `X` completes — only `E` (the real outermost cycle) still is. That's why
/// we need to resolve all cycle heads recursively, so that `X` completes with `[B, C, D, E]`.
fn collect_all_cycle_heads(
zalsa: &Zalsa,
cycle_heads: &mut CycleHeads,
database_key_index: DatabaseKeyIndex,
iteration_count: IterationCount,
) -> (IterationCount, bool) {
let mut missing_heads: SmallVec<[(DatabaseKeyIndex, IterationCount); 1]> =
SmallVec::new_const();
let mut max_iteration_count = iteration_count;
let mut depends_on_self = false;
fn collect_recursive(
zalsa: &Zalsa,
current_head: DatabaseKeyIndex,
me: DatabaseKeyIndex,
query_heads: &CycleHeads,
missing_heads: &mut SmallVec<[(DatabaseKeyIndex, IterationCount); 4]>,
) -> (IterationCount, bool) {
if current_head == me {
return (IterationCount::initial(), true);
}

for head in cycle_heads.iter() {
max_iteration_count = max_iteration_count.max(head.iteration_count.load());
depends_on_self |= head.database_key_index == database_key_index;
let mut max_iteration_count = IterationCount::initial();
let mut depends_on_self = false;

let ingredient = zalsa.lookup_ingredient(head.database_key_index.ingredient_index());
let ingredient = zalsa.lookup_ingredient(current_head.ingredient_index());

let provisional_status = ingredient
.provisional_status(zalsa, head.database_key_index.key_index())
.provisional_status(zalsa, current_head.key_index())
.expect("cycle head memo must have been created during the execution");

// A query should only ever depend on other heads that are provisional.
// If this invariant is violated, it means that this query participates in a cycle,
// but it wasn't executed in the last iteration of said cycle.
assert!(provisional_status.is_provisional());

for nested_head in provisional_status.cycle_heads() {
let nested_as_tuple = (
nested_head.database_key_index,
nested_head.iteration_count.load(),
);
for head in provisional_status.cycle_heads() {
let iteration_count = head.iteration_count.load();
max_iteration_count = max_iteration_count.max(iteration_count);

if !cycle_heads.contains(&nested_head.database_key_index)
&& !missing_heads.contains(&nested_as_tuple)
{
missing_heads.push(nested_as_tuple);
if query_heads.contains(&head.database_key_index) {
continue;
}

let head_as_tuple = (head.database_key_index, iteration_count);

if missing_heads.contains(&head_as_tuple) {
continue;
}

missing_heads.push((head.database_key_index, iteration_count));

let (nested_max_iteration_count, nested_depends_on_self) = collect_recursive(
zalsa,
head.database_key_index,
me,
query_heads,
missing_heads,
);

max_iteration_count = max_iteration_count.max(nested_max_iteration_count);
depends_on_self |= nested_depends_on_self;
}

(max_iteration_count, depends_on_self)
}

for (head_key, iteration_count) in missing_heads {
max_iteration_count = max_iteration_count.max(iteration_count);
depends_on_self |= head_key == database_key_index;
let mut missing_heads: SmallVec<[(DatabaseKeyIndex, IterationCount); 4]> = SmallVec::new();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume the choice of 4 is ~arbitrary here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

let mut max_iteration_count = iteration_count;
let mut depends_on_self = false;

for head in &*cycle_heads {
let (recursive_max_iteration, recursive_depends_on_self) = collect_recursive(
zalsa,
head.database_key_index,
database_key_index,
cycle_heads,
&mut missing_heads,
);

cycle_heads.insert(head_key, iteration_count);
max_iteration_count = max_iteration_count.max(recursive_max_iteration);
depends_on_self |= recursive_depends_on_self;
}

for (head, iteration) in missing_heads {
cycle_heads.insert(head, iteration);
}

(max_iteration_count, depends_on_self)
Expand All @@ -570,18 +616,14 @@ fn complete_cycle_participant(
active_query: ActiveQueryGuard,
claim_guard: &mut ClaimGuard,
cycle_heads: CycleHeads,
outer_cycle: Option<DatabaseKeyIndex>,
outer_cycle: DatabaseKeyIndex,
iteration_count: IterationCount,
) -> CompletedQuery {
// For as long as this query participates in any cycle, don't release its lock, instead
// transfer it to the outermost cycle head (if any). This prevents any other thread
// transfer it to the outermost cycle head. This prevents any other thread
// from claiming this query (all cycle heads are potential entry points to the same cycle),
// which would result in them competing for the same locks (we want the locks to converge to a single cycle head).
if let Some(outer_cycle) = outer_cycle {
claim_guard.set_release_mode(ReleaseMode::TransferTo(outer_cycle));
} else {
claim_guard.set_release_mode(ReleaseMode::SelfOnly);
}
claim_guard.set_release_mode(ReleaseMode::TransferTo(outer_cycle));

let database_key_index = active_query.database_key_index;
let mut completed_query = active_query.pop();
Expand All @@ -593,9 +635,13 @@ fn complete_cycle_participant(
panic!("{database_key_index:?}: execute: too many cycle iterations")
});

// The outermost query only bumps the iteration count of cycle heads. It doesn't
// increment the iteration count for cycle participants. It's important that we bump the
// iteration count here or the head will re-use the same iteration count in the next
// iteration (which can break cache invalidation).
completed_query
.revisions
.update_iteration_count_mut(database_key_index, iteration_count);
.update_cycle_participant_iteration_count(iteration_count);

completed_query
}
Expand All @@ -604,9 +650,7 @@ fn complete_cycle_participant(
///
/// Returns `Ok` if the cycle head has converged or if it is part of an outer cycle.
/// Returns `Err` if the cycle head needs to keep iterating.
#[expect(clippy::too_many_arguments)]
fn try_complete_cycle_head(
zalsa: &Zalsa,
active_query: ActiveQueryGuard,
claim_guard: &mut ClaimGuard,
cycle_heads: CycleHeads,
Expand Down Expand Up @@ -650,6 +694,8 @@ fn try_complete_cycle_head(
return Ok(completed_query);
}

let zalsa = claim_guard.zalsa();

// If this is the outermost cycle, test if all inner cycles have converged as well.
let converged = this_converged
&& cycle_heads.iter_not_eq(me).all(|head| {
Expand Down
24 changes: 24 additions & 0 deletions src/zalsa_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,30 @@ impl QueryRevisions {
.update_iteration_count(database_key_index, iteration_count);
}

/// Updates the iteration count of the memo without updating the iteration in `cycle_heads`.
///
/// Don't call this method on a cycle head, as it results in diverging iteration counts
/// between what's in cycle heads and stored on the memo.
pub(crate) fn update_cycle_participant_iteration_count(
    &mut self,
    iteration_count: IterationCount,
) {
    if let Some(extra) = &mut self.extra.0 {
        // The revisions already carry extra data; overwrite the stored iteration in place.
        extra.iteration.store_mut(iteration_count);
    } else {
        // No extra data was allocated yet: create it with empty accumulated values and
        // empty cycle heads so the iteration count has somewhere to live.
        self.extra = QueryRevisionsExtra::new(
            #[cfg(feature = "accumulator")]
            AccumulatedMap::default(),
            ThinVec::default(),
            empty_cycle_heads().clone(),
            iteration_count,
        );
    }
}

/// Updates the iteration count if this query has any cycle heads. Otherwise it's a no-op.
pub(crate) fn update_iteration_count_mut(
&mut self,
Expand Down
98 changes: 98 additions & 0 deletions tests/cycle_stale_cycle_heads.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#![cfg(feature = "inventory")]

//! Test for stale cycle heads when nested cycles are discovered incrementally.
//!
//! Scenario from ty:
//! ```txt
//! E -> C -> D -> B -> A -> B (cycle)
//! -- A completes, heads = [B]
//! E -> C -> D -> B -> C (cycle)
//! -> D (cycle)
//! -- B completes, heads = [B, C, D]
//! E -> C -> D -> E (cycle)
//! -- D completes, heads = [E, B, C, D]
//! E -> C
//! -- C completes, heads = [E, B, C, D]
//! E -> X -> A
//! -- X completes, heads = [B]
//! ```
//!
//! Note how `X` only depends on `B`, but not on `E`, unless we collect the cycle heads transitively,
//! which is what this test is asserting.

#[salsa::input]
struct Input {
    // Base value fed into the cycle. The test constructs it with 50, so the
    // fixpoint-initial 0 dominates every `min` in the cycle queries.
    value: u32,
}

// Outer cycle head - should iterate
#[salsa::tracked(cycle_initial = initial_zero)]
fn query_e(db: &dyn salsa::Database, input: Input) -> u32 {
    // Establish the nested cycles through C first; only afterwards call X,
    // which reads A's memoized result while A's recorded cycle heads are
    // already stale and E is still on the stack. Rust evaluates arguments
    // left-to-right, so C runs strictly before X here.
    query_c(db, input).min(query_x(db, input))
}

#[salsa::tracked(cycle_initial = initial_zero)]
fn query_c(db: &dyn salsa::Database, input: Input) -> u32 {
    // C is a pure pass-through to D; D in turn closes the cycles back to B and E.
    query_d(db, input)
}

#[salsa::tracked(cycle_initial = initial_zero)]
fn query_d(db: &dyn salsa::Database, input: Input) -> u32 {
    // Read B first (triggering the inner A<->B cycle discovery), then call
    // back into E to close the outermost cycle. The call order matters for
    // the cycle-discovery sequence this test exercises.
    let from_b = query_b(db, input);
    let from_e = query_e(db, input);

    // Equivalent to `from_b.min(from_e)`.
    if from_b < from_e { from_b } else { from_e }
}

#[salsa::tracked(cycle_initial = initial_zero)]
fn query_b(db: &dyn salsa::Database, input: Input) -> u32 {
    // Calling A first detects the A<->B cycle, after which A completes with
    // B as its only recorded cycle head.
    let from_a = query_a(db, input);

    // Reading C and D afterwards reveals that B also participates in their
    // cycles, growing the set of cycle heads incrementally.
    let from_c = query_c(db, input);
    let from_d = query_d(db, input);

    // Clamp the sum (same addition order as before) so iteration converges.
    u32::min(from_a + from_d + from_c, 50)
}

#[salsa::tracked(cycle_initial = initial_zero)]
fn query_a(db: &dyn salsa::Database, input: Input) -> u32 {
    // Reading B creates the inner A<->B cycle.
    let from_b = query_b(db, input);

    // Also depend on the input so A carries a non-cycle dependency as well.
    let input_value = input.value(db);

    // Equivalent to `from_b.max(input_value)`.
    if from_b > input_value { from_b } else { input_value }
}

#[salsa::tracked(cycle_initial = initial_zero)]
fn query_x(db: &dyn salsa::Database, input: Input) -> u32 {
    // Reads A's memoized result, whose recorded cycle heads are stale by this
    // point; the transitive cycle-head collection under test must repair that.
    query_a(db, input)
}

// Fixpoint-initial value shared by all cycle queries above: iteration starts from 0.
fn initial_zero(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> u32 {
    0
}

#[test]
fn run() {
    let db = salsa::DatabaseImpl::new();
    let input = Input::new(&db, 50);

    // All participants are min-ed against the fixpoint-initial 0, so the
    // whole cycle converges to 0.
    assert_eq!(query_e(&db, input), 0);
}
6 changes: 3 additions & 3 deletions tests/parallel/cycle_nested_deep_conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,17 @@ fn the_test() {
result
});
let t2 = thread::spawn(move || {
let _span = tracing::debug_span!("t4", thread_id = ?thread::current().id()).entered();
let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered();
db_t4.wait_for(1);
query_b(&db_t4)
});
let t3 = thread::spawn(move || {
let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered();
let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered();
db_t2.wait_for(1);
query_d(&db_t2)
});
let t4 = thread::spawn(move || {
let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered();
let _span = tracing::debug_span!("t4", thread_id = ?thread::current().id()).entered();
db_t3.wait_for(1);
query_e(&db_t3)
});
Expand Down