diff --git a/src/accumulator/accumulated_map.rs b/src/accumulator/accumulated_map.rs index f922de9f3..61032622d 100644 --- a/src/accumulator/accumulated_map.rs +++ b/src/accumulator/accumulated_map.rs @@ -74,6 +74,14 @@ impl InputAccumulatedValues { pub const fn is_empty(self) -> bool { matches!(self, Self::Empty) } + + pub fn or_else(self, other: impl FnOnce() -> Self) -> Self { + if self.is_any() { + Self::Any + } else { + other() + } + } } impl ops::BitOr for InputAccumulatedValues { diff --git a/src/active_query.rs b/src/active_query.rs index 71ec0bbd6..55d3bed73 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -66,18 +66,23 @@ pub(crate) struct ActiveQuery { } impl ActiveQuery { + #[inline] pub(super) fn add_read( &mut self, input: DatabaseKeyIndex, durability: Durability, - revision: Revision, - accumulated: InputAccumulatedValues, + changed_at: Revision, + has_accumulated: bool, + accumulated_inputs: &AtomicInputAccumulatedValues, cycle_heads: &CycleHeads, ) { self.durability = self.durability.min(durability); - self.changed_at = self.changed_at.max(revision); + self.changed_at = self.changed_at.max(changed_at); self.input_outputs.insert(QueryEdge::Input(input)); - self.accumulated_inputs |= accumulated; + self.accumulated_inputs = self.accumulated_inputs.or_else(|| match has_accumulated { + true => InputAccumulatedValues::Any, + false => accumulated_inputs.load(), + }); self.cycle_heads.extend(cycle_heads); } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 2d5a373b9..fe311d5e1 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,7 +1,5 @@ -use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl, VerifyResult}; -use crate::runtime::StampedValue; use crate::table::sync::ClaimResult; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions, ZalsaLocal}; @@ -18,26 +16,19 @@ where let memo = self.refresh_memo(db, id); // SAFETY: We just refreshed the memo so it is guaranteed to contain a value now. let memo_value = unsafe { memo.value.as_ref().unwrap_unchecked() }; - let StampedValue { - value, - durability, - changed_at, - } = memo.revisions.stamped_value(memo_value); self.lru.record_use(id); zalsa_local.report_tracked_read( self.database_key_index(id), - durability, - changed_at, - match &memo.revisions.accumulated { - Some(_) => InputAccumulatedValues::Any, - None => memo.revisions.accumulated_inputs.load(), - }, + memo.revisions.durability, + memo.revisions.changed_at, + memo.revisions.accumulated.is_some(), + &memo.revisions.accumulated_inputs, memo.cycle_heads(), ); - value + memo_value } #[inline] diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index d5a7d2fd9..6421961d3 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -228,13 +228,7 @@ where update: ShallowUpdate, ) { if let ShallowUpdate::HigherDurability(revision_now) = update { - memo.mark_as_verified( - db, - revision_now, - database_key_index, - memo.revisions.accumulated_inputs.load(), - ); - + memo.mark_as_verified(db, revision_now, database_key_index); memo.mark_outputs_as_verified(zalsa, db.as_dyn_database(), database_key_index); } } @@ -457,12 +451,8 @@ where let in_heads = cycle_heads.remove(&database_key_index); if cycle_heads.is_empty() { - old_memo.mark_as_verified( - db, - zalsa.current_revision(), - database_key_index, - inputs, - ); + old_memo.mark_as_verified(db, zalsa.current_revision(), database_key_index); + old_memo.revisions.accumulated_inputs.store(inputs); if is_provisional { old_memo @@ -475,11 +465,7 @@ where continue 'cycle; } } - - break 'cycle VerifyResult::Unchanged( - InputAccumulatedValues::Empty, - cycle_heads, - ); + break 'cycle VerifyResult::Unchanged(inputs, cycle_heads); } } } diff --git a/src/function/memo.rs b/src/function/memo.rs index 5e28d9265..1720bfe8b 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -5,7 +5,6 @@ use std::fmt::{Debug, Formatter}; use std::ptr::NonNull; use std::sync::atomic::Ordering; -use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::cycle::{CycleHeads, CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}; use crate::function::{Configuration, IngredientImpl}; use crate::key::DatabaseKeyIndex; @@ -218,12 +217,14 @@ impl Memo { /// Mark memo as having been verified in the `revision_now`, which should /// be the current revision. + /// The caller is responsible to update the memo's `accumulated` state if heir accumulated + /// values have changed since. + #[inline] pub(super) fn mark_as_verified( &self, db: &Db, revision_now: Revision, database_key_index: DatabaseKeyIndex, - accumulated: InputAccumulatedValues, ) { db.salsa_event(&|| { Event::new(EventKind::DidValidateMemoizedValue { @@ -232,7 +233,6 @@ impl Memo { }); self.verified_at.store(revision_now); - self.revisions.accumulated_inputs.store(accumulated); } pub(super) fn mark_outputs_as_verified( diff --git a/src/function/specify.rs b/src/function/specify.rs index 4c1589c58..4d5c7169a 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -134,11 +134,9 @@ where } let database_key_index = self.database_key_index(key); - memo.mark_as_verified( - db, - zalsa.current_revision(), - database_key_index, - InputAccumulatedValues::Empty, - ); + memo.mark_as_verified(db, zalsa.current_revision(), database_key_index); + memo.revisions + .accumulated_inputs + .store(InputAccumulatedValues::Empty); } } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index c93d1efa1..c3ac915ae 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -5,14 +5,12 @@ use std::sync::atomic::AtomicBool; use rustc_hash::FxHashMap; use tracing::debug; -use crate::accumulator::accumulated_map::{ - AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues, -}; +use crate::accumulator::accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues}; use crate::active_query::QueryStack; use crate::cycle::CycleHeads; use crate::durability::Durability; use crate::key::DatabaseKeyIndex; -use crate::runtime::{Stamp, StampedValue}; +use crate::runtime::Stamp; use crate::table::{PageIndex, Slot, Table}; use crate::tracked_struct::{Disambiguator, Identity, IdentityHash, IdentityMap}; use crate::zalsa::IngredientIndex; @@ -159,12 +157,14 @@ impl ZalsaLocal { } /// Register that currently active query reads the given input + #[inline] pub(crate) fn report_tracked_read( &self, input: DatabaseKeyIndex, durability: Durability, changed_at: Revision, - accumulated: InputAccumulatedValues, + has_accumulated: bool, + accumulated_inputs: &AtomicInputAccumulatedValues, cycle_heads: &CycleHeads, ) { debug!( @@ -173,7 +173,14 @@ impl ZalsaLocal { ); self.with_query_stack(|stack| { if let Some(top_query) = stack.last_mut() { - top_query.add_read(input, durability, changed_at, accumulated, cycle_heads); + top_query.add_read( + input, + durability, + changed_at, + has_accumulated, + accumulated_inputs, + cycle_heads, + ); } }) } @@ -345,33 +352,6 @@ impl QueryRevisions { cycle_heads: CycleHeads::initial(query), } } - - pub(crate) fn stamped_value(&self, value: V) -> StampedValue { - self.stamp_template().stamp(value) - } - - pub(crate) fn stamp_template(&self) -> StampTemplate { - StampTemplate { - durability: self.durability, - changed_at: self.changed_at, - } - } -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub(crate) struct StampTemplate { - durability: Durability, - changed_at: Revision, -} - -impl StampTemplate { - pub(crate) fn stamp(self, value: V) -> StampedValue { - StampedValue { - value, - durability: self.durability, - changed_at: self.changed_at, - } - } } /// Tracks the way that a memoized value for a query was created.