diff --git a/crates/trie/trie/src/trie_cursor/in_memory.rs b/crates/trie/trie/src/trie_cursor/in_memory.rs index 7951a6b791c..5fde874017e 100644 --- a/crates/trie/trie/src/trie_cursor/in_memory.rs +++ b/crates/trie/trie/src/trie_cursor/in_memory.rs @@ -56,30 +56,61 @@ where pub struct InMemoryTrieCursor<'a, C> { /// The underlying cursor. cursor: C, - /// Whether the underlying cursor should be ignored (when storage trie was wiped). - cursor_wiped: bool, - /// Entry that `cursor` is currently pointing to. - cursor_entry: Option<(Nibbles, BranchNodeCompact)>, + /// Tracks whether the DB cursor is available, positioned, or exhausted. + db_cursor_state: DbCursorState, /// Forward-only in-memory cursor over storage trie nodes. in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option>, /// The key most recently returned from the Cursor. last_key: Option, + #[cfg(debug_assertions)] /// Whether an initial seek was called. seeked: bool, /// Reference to the full trie updates. trie_updates: &'a TrieUpdatesSorted, } +#[derive(Debug)] +enum DbCursorState { + NeedsPosition, + Positioned((Nibbles, BranchNodeCompact)), + Exhausted, + Wiped, +} + +impl DbCursorState { + const fn new(cursor_wiped: bool) -> Self { + if cursor_wiped { + Self::Wiped + } else { + Self::NeedsPosition + } + } + + const fn entry(&self) -> Option<&(Nibbles, BranchNodeCompact)> { + match self { + Self::Positioned(entry) => Some(entry), + Self::NeedsPosition | Self::Exhausted | Self::Wiped => None, + } + } + + fn set_entry(&mut self, entry: Option<(Nibbles, BranchNodeCompact)>) { + *self = match entry { + Some(entry) => Self::Positioned(entry), + None => Self::Exhausted, + }; + } +} + impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { /// Create new account trie cursor which combines a DB cursor and the trie updates. pub fn new_account(cursor: C, trie_updates: &'a TrieUpdatesSorted) -> Self { let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates.account_nodes_ref()); Self { cursor, - cursor_wiped: false, - cursor_entry: None, + db_cursor_state: DbCursorState::NeedsPosition, in_memory_cursor, last_key: None, + #[cfg(debug_assertions)] seeked: false, trie_updates, } @@ -96,10 +127,10 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { Self::get_storage_overlay(trie_updates, hashed_address); Self { cursor, - cursor_wiped, - cursor_entry: None, + db_cursor_state: DbCursorState::new(cursor_wiped), in_memory_cursor, last_key: None, + #[cfg(debug_assertions)] seeked: false, trie_updates, } @@ -119,7 +150,7 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { /// Returns a mutable reference to the underlying cursor if it's not wiped, None otherwise. fn get_cursor_mut(&mut self) -> Option<&mut C> { - (!self.cursor_wiped).then_some(&mut self.cursor) + (!matches!(self.db_cursor_state, DbCursorState::Wiped)).then_some(&mut self.cursor) } /// Asserts that the next entry to be returned from the cursor is not previous to the last entry @@ -135,31 +166,38 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { self.last_key = next_key; } - /// Seeks the `cursor_entry` field of the struct using the cursor. + /// Positions the DB cursor state using the underlying cursor when needed. fn cursor_seek(&mut self, key: Nibbles) -> Result<(), DatabaseError> { // Only seek if: // 1. We have a cursor entry and need to seek forward (entry.0 < key), OR - // 2. We have no cursor entry and haven't seeked yet (!self.seeked) - let should_seek = match self.cursor_entry.as_ref() { - Some(entry) => entry.0 < key, - None => !self.seeked, + // 2. The DB cursor needs to be positioned. + let should_seek = match &self.db_cursor_state { + DbCursorState::NeedsPosition => true, + DbCursorState::Positioned((entry_key, _)) => entry_key < &key, + DbCursorState::Exhausted | DbCursorState::Wiped => false, }; if should_seek { - self.cursor_entry = self.get_cursor_mut().map(|c| c.seek(key)).transpose()?.flatten(); + let entry = self.get_cursor_mut().map(|c| c.seek(key)).transpose()?.flatten(); + self.db_cursor_state.set_entry(entry); } Ok(()) } - /// Seeks the `cursor_entry` field of the struct to the subsequent entry using the cursor. + /// Advances the DB cursor state to the subsequent entry using the underlying cursor. fn cursor_next(&mut self) -> Result<(), DatabaseError> { - debug_assert!(self.seeked); + #[cfg(debug_assertions)] + { + debug_assert!(self.seeked); + debug_assert!(!matches!(self.db_cursor_state, DbCursorState::NeedsPosition)); + } - // If the previous entry is `None`, and we've done a seek previously, then the cursor is - // exhausted and we shouldn't call `next` again. - if self.cursor_entry.is_some() { - self.cursor_entry = self.get_cursor_mut().map(|c| c.next()).transpose()?.flatten(); + // Exhausted and wiped states are stable; only advance if the DB cursor currently points to + // an entry. + if matches!(self.db_cursor_state, DbCursorState::Positioned(_)) { + let entry = self.get_cursor_mut().map(|c| c.next()).transpose()?.flatten(); + self.db_cursor_state.set_entry(entry); } Ok(()) @@ -172,9 +210,12 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { /// node. fn choose_next_entry(&mut self) -> Result, DatabaseError> { loop { - match (self.in_memory_cursor.current().cloned(), &self.cursor_entry) { + let mem_entry = self.in_memory_cursor.current().cloned(); + let db_entry = self.db_cursor_state.entry(); + + match (mem_entry, db_entry) { (Some((mem_key, None)), _) - if self.cursor_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) => + if db_entry.is_none_or(|(db_key, _)| &mem_key < db_key) => { // If overlay has a removed node but DB cursor is exhausted or ahead of the // in-memory cursor then move ahead in-memory, as there might be further @@ -188,7 +229,7 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { self.cursor_next()?; } (Some((mem_key, Some(node))), _) - if self.cursor_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) => + if db_entry.is_none_or(|(db_key, _)| &mem_key <= db_key) => { // If overlay returns a node prior to the DB's node, or the DB is exhausted, // then we return the overlay's node. @@ -198,7 +239,7 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { // - mem_key > db_key // - overlay is exhausted // Return the db_entry. If DB is also exhausted then this returns None. - _ => return Ok(self.cursor_entry.clone()), + _ => return Ok(db_entry.cloned()), } } } @@ -209,16 +250,38 @@ impl TrieCursor for InMemoryTrieCursor<'_, C> { &mut self, key: Nibbles, ) -> Result, DatabaseError> { - self.cursor_seek(key)?; let mem_entry = self.in_memory_cursor.seek(&key); - self.seeked = true; + if let Some((mem_key, entry_inner)) = mem_entry && + *mem_key == key + { + #[cfg(debug_assertions)] + { + self.seeked = true; + } - let entry = match (mem_entry, &self.cursor_entry) { - (Some((mem_key, entry_inner)), _) if *mem_key == key => { - entry_inner.clone().map(|node| (key, node)) + // An exact overlay hit can move the logical cursor ahead without touching the DB. If + // the DB cursor was still behind this key, force a re-seek before the next DB-backed + // operation so `next()` cannot return a stale earlier entry. + if matches!(&self.db_cursor_state, DbCursorState::Positioned((db_key, _)) if db_key < &key) + { + self.db_cursor_state = DbCursorState::NeedsPosition; } - (_, Some((db_key, node))) if db_key == &key => Some((key, node.clone())), + + let entry = entry_inner.clone().map(|node| (key, node)); + self.set_last_key(&entry); + return Ok(entry) + } + + self.cursor_seek(key)?; + + #[cfg(debug_assertions)] + { + self.seeked = true; + } + + let entry = match self.db_cursor_state.entry() { + Some((db_key, node)) if db_key == &key => Some((key, node.clone())), _ => None, }; @@ -233,7 +296,10 @@ impl TrieCursor for InMemoryTrieCursor<'_, C> { self.cursor_seek(key)?; self.in_memory_cursor.seek(&key); - self.seeked = true; + #[cfg(debug_assertions)] + { + self.seeked = true; + } let entry = self.choose_next_entry()?; self.set_last_key(&entry); @@ -241,7 +307,10 @@ impl TrieCursor for InMemoryTrieCursor<'_, C> { } fn next(&mut self) -> Result, DatabaseError> { - debug_assert!(self.seeked, "Cursor must be seek'd before next is called"); + #[cfg(debug_assertions)] + { + debug_assert!(self.seeked, "Cursor must be seek'd before next is called"); + } // A `last_key` of `None` indicates that the cursor is exhausted. let Some(last_key) = self.last_key else { @@ -256,7 +325,11 @@ impl TrieCursor for InMemoryTrieCursor<'_, C> { self.in_memory_cursor.first_after(&last_key); } - if let Some((key, _)) = &self.cursor_entry && + if matches!(self.db_cursor_state, DbCursorState::NeedsPosition) { + self.cursor_seek(last_key)?; + } + + if let Some((key, _)) = self.db_cursor_state.entry() && key == &last_key { self.cursor_next()?; @@ -275,23 +348,15 @@ impl TrieCursor for InMemoryTrieCursor<'_, C> { } fn reset(&mut self) { - let Self { - cursor, - cursor_wiped, - cursor_entry, - in_memory_cursor, - last_key, - seeked, - trie_updates: _, - } = self; + self.cursor.reset(); + self.in_memory_cursor.reset(); - cursor.reset(); - in_memory_cursor.reset(); - - *cursor_wiped = false; - *cursor_entry = None; - *last_key = None; - *seeked = false; + self.db_cursor_state = DbCursorState::NeedsPosition; + self.last_key = None; + #[cfg(debug_assertions)] + { + self.seeked = false; + } } } @@ -299,8 +364,10 @@ impl TrieStorageCursor for InMemoryTrieCursor<'_, C> { fn set_hashed_address(&mut self, hashed_address: B256) { self.reset(); self.cursor.set_hashed_address(hashed_address); - (self.in_memory_cursor, self.cursor_wiped) = + let (in_memory_cursor, cursor_wiped) = Self::get_storage_overlay(self.trie_updates, hashed_address); + self.in_memory_cursor = in_memory_cursor; + self.db_cursor_state = DbCursorState::new(cursor_wiped); } } @@ -507,7 +574,7 @@ mod tests { let db_nodes_map: BTreeMap = db_nodes.into_iter().collect(); let db_nodes_arc = Arc::new(db_nodes_map); let visited_keys = Arc::new(Mutex::new(Vec::new())); - let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys); + let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys.clone()); let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default()); let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates); @@ -520,6 +587,7 @@ mod tests { BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None) )) ); + assert!(visited_keys.lock().is_empty(), "exact overlay hit should not touch the DB cursor"); let result = cursor.seek_exact(Nibbles::from_nibbles([0x3])).unwrap(); assert_eq!(