Skip to content
172 changes: 120 additions & 52 deletions crates/trie/trie/src/trie_cursor/in_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,30 +56,61 @@ where
pub struct InMemoryTrieCursor<'a, C> {
/// The underlying cursor.
cursor: C,
/// Whether the underlying cursor should be ignored (when storage trie was wiped).
cursor_wiped: bool,
/// Entry that `cursor` is currently pointing to.
cursor_entry: Option<(Nibbles, BranchNodeCompact)>,
/// Tracks whether the DB cursor is available, positioned, or exhausted.
db_cursor_state: DbCursorState,
/// Forward-only in-memory cursor over storage trie nodes.
in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>,
/// The key most recently returned from the Cursor.
last_key: Option<Nibbles>,
#[cfg(debug_assertions)]
/// Whether an initial seek was called.
seeked: bool,
/// Reference to the full trie updates.
trie_updates: &'a TrieUpdatesSorted,
}

#[derive(Debug)]
enum DbCursorState {
NeedsPosition,
Positioned((Nibbles, BranchNodeCompact)),
Exhausted,
Wiped,
}

impl DbCursorState {
const fn new(cursor_wiped: bool) -> Self {
if cursor_wiped {
Self::Wiped
} else {
Self::NeedsPosition
}
}

const fn entry(&self) -> Option<&(Nibbles, BranchNodeCompact)> {
match self {
Self::Positioned(entry) => Some(entry),
Self::NeedsPosition | Self::Exhausted | Self::Wiped => None,
}
}

fn set_entry(&mut self, entry: Option<(Nibbles, BranchNodeCompact)>) {
*self = match entry {
Some(entry) => Self::Positioned(entry),
None => Self::Exhausted,
};
}
}

impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
/// Create new account trie cursor which combines a DB cursor and the trie updates.
pub fn new_account(cursor: C, trie_updates: &'a TrieUpdatesSorted) -> Self {
let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates.account_nodes_ref());
Self {
cursor,
cursor_wiped: false,
cursor_entry: None,
db_cursor_state: DbCursorState::NeedsPosition,
in_memory_cursor,
last_key: None,
#[cfg(debug_assertions)]
seeked: false,
trie_updates,
}
Expand All @@ -96,10 +127,10 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
Self::get_storage_overlay(trie_updates, hashed_address);
Self {
cursor,
cursor_wiped,
cursor_entry: None,
db_cursor_state: DbCursorState::new(cursor_wiped),
in_memory_cursor,
last_key: None,
#[cfg(debug_assertions)]
seeked: false,
trie_updates,
}
Expand All @@ -119,7 +150,7 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {

/// Returns a mutable reference to the underlying cursor if it's not wiped, None otherwise.
fn get_cursor_mut(&mut self) -> Option<&mut C> {
(!self.cursor_wiped).then_some(&mut self.cursor)
(!matches!(self.db_cursor_state, DbCursorState::Wiped)).then_some(&mut self.cursor)
}

/// Asserts that the next entry to be returned from the cursor is not previous to the last entry
Expand All @@ -135,31 +166,38 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
self.last_key = next_key;
}

/// Seeks the `cursor_entry` field of the struct using the cursor.
/// Positions the DB cursor state using the underlying cursor when needed.
fn cursor_seek(&mut self, key: Nibbles) -> Result<(), DatabaseError> {
// Only seek if:
// 1. We have a cursor entry and need to seek forward (entry.0 < key), OR
// 2. We have no cursor entry and haven't seeked yet (!self.seeked)
let should_seek = match self.cursor_entry.as_ref() {
Some(entry) => entry.0 < key,
None => !self.seeked,
// 2. The DB cursor needs to be positioned.
let should_seek = match &self.db_cursor_state {
DbCursorState::NeedsPosition => true,
DbCursorState::Positioned((entry_key, _)) => entry_key < &key,
DbCursorState::Exhausted | DbCursorState::Wiped => false,
};

if should_seek {
self.cursor_entry = self.get_cursor_mut().map(|c| c.seek(key)).transpose()?.flatten();
let entry = self.get_cursor_mut().map(|c| c.seek(key)).transpose()?.flatten();
self.db_cursor_state.set_entry(entry);
}

Ok(())
}

/// Seeks the `cursor_entry` field of the struct to the subsequent entry using the cursor.
/// Advances the DB cursor state to the subsequent entry using the underlying cursor.
fn cursor_next(&mut self) -> Result<(), DatabaseError> {
debug_assert!(self.seeked);
#[cfg(debug_assertions)]
{
debug_assert!(self.seeked);
debug_assert!(!matches!(self.db_cursor_state, DbCursorState::NeedsPosition));
}

// If the previous entry is `None`, and we've done a seek previously, then the cursor is
// exhausted and we shouldn't call `next` again.
if self.cursor_entry.is_some() {
self.cursor_entry = self.get_cursor_mut().map(|c| c.next()).transpose()?.flatten();
// Exhausted and wiped states are stable; only advance if the DB cursor currently points to
// an entry.
if matches!(self.db_cursor_state, DbCursorState::Positioned(_)) {
let entry = self.get_cursor_mut().map(|c| c.next()).transpose()?.flatten();
self.db_cursor_state.set_entry(entry);
}

Ok(())
Expand All @@ -172,9 +210,12 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
/// node.
fn choose_next_entry(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
loop {
match (self.in_memory_cursor.current().cloned(), &self.cursor_entry) {
let mem_entry = self.in_memory_cursor.current().cloned();
let db_entry = self.db_cursor_state.entry();

match (mem_entry, db_entry) {
(Some((mem_key, None)), _)
Comment thread
mediocregopher marked this conversation as resolved.
if self.cursor_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) =>
if db_entry.is_none_or(|(db_key, _)| &mem_key < db_key) =>
{
// If overlay has a removed node but DB cursor is exhausted or ahead of the
// in-memory cursor then move ahead in-memory, as there might be further
Expand All @@ -188,7 +229,7 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
self.cursor_next()?;
}
(Some((mem_key, Some(node))), _)
Comment thread
mediocregopher marked this conversation as resolved.
if self.cursor_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) =>
if db_entry.is_none_or(|(db_key, _)| &mem_key <= db_key) =>
{
// If overlay returns a node prior to the DB's node, or the DB is exhausted,
// then we return the overlay's node.
Expand All @@ -198,7 +239,7 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
// - mem_key > db_key
// - overlay is exhausted
// Return the db_entry. If DB is also exhausted then this returns None.
_ => return Ok(self.cursor_entry.clone()),
_ => return Ok(db_entry.cloned()),
}
}
}
Expand All @@ -209,16 +250,38 @@ impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
&mut self,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
self.cursor_seek(key)?;
let mem_entry = self.in_memory_cursor.seek(&key);

self.seeked = true;
if let Some((mem_key, entry_inner)) = mem_entry &&
*mem_key == key
{
#[cfg(debug_assertions)]
{
self.seeked = true;
}

let entry = match (mem_entry, &self.cursor_entry) {
(Some((mem_key, entry_inner)), _) if *mem_key == key => {
entry_inner.clone().map(|node| (key, node))
// An exact overlay hit can move the logical cursor ahead without touching the DB. If
// the DB cursor was still behind this key, force a re-seek before the next DB-backed
// operation so `next()` cannot return a stale earlier entry.
if matches!(&self.db_cursor_state, DbCursorState::Positioned((db_key, _)) if db_key < &key)
{
self.db_cursor_state = DbCursorState::NeedsPosition;
}
(_, Some((db_key, node))) if db_key == &key => Some((key, node.clone())),

let entry = entry_inner.clone().map(|node| (key, node));
self.set_last_key(&entry);
return Ok(entry)
}

self.cursor_seek(key)?;

#[cfg(debug_assertions)]
{
self.seeked = true;
}

let entry = match self.db_cursor_state.entry() {
Some((db_key, node)) if db_key == &key => Some((key, node.clone())),
_ => None,
};

Expand All @@ -233,15 +296,21 @@ impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
self.cursor_seek(key)?;
self.in_memory_cursor.seek(&key);

self.seeked = true;
#[cfg(debug_assertions)]
{
self.seeked = true;
}

let entry = self.choose_next_entry()?;
self.set_last_key(&entry);
Ok(entry)
}

fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
debug_assert!(self.seeked, "Cursor must be seek'd before next is called");
#[cfg(debug_assertions)]
{
debug_assert!(self.seeked, "Cursor must be seek'd before next is called");
}

// A `last_key` of `None` indicates that the cursor is exhausted.
let Some(last_key) = self.last_key else {
Expand All @@ -256,7 +325,11 @@ impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
self.in_memory_cursor.first_after(&last_key);
}

if let Some((key, _)) = &self.cursor_entry &&
if matches!(self.db_cursor_state, DbCursorState::NeedsPosition) {
self.cursor_seek(last_key)?;
}

if let Some((key, _)) = self.db_cursor_state.entry() &&
key == &last_key
{
self.cursor_next()?;
Expand All @@ -275,32 +348,26 @@ impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
}

fn reset(&mut self) {
let Self {
cursor,
cursor_wiped,
cursor_entry,
in_memory_cursor,
last_key,
seeked,
trie_updates: _,
} = self;
self.cursor.reset();
self.in_memory_cursor.reset();

cursor.reset();
in_memory_cursor.reset();

*cursor_wiped = false;
*cursor_entry = None;
*last_key = None;
*seeked = false;
self.db_cursor_state = DbCursorState::NeedsPosition;
self.last_key = None;
#[cfg(debug_assertions)]
{
self.seeked = false;
}
}
}

impl<C: TrieStorageCursor> TrieStorageCursor for InMemoryTrieCursor<'_, C> {
fn set_hashed_address(&mut self, hashed_address: B256) {
self.reset();
self.cursor.set_hashed_address(hashed_address);
(self.in_memory_cursor, self.cursor_wiped) =
let (in_memory_cursor, cursor_wiped) =
Self::get_storage_overlay(self.trie_updates, hashed_address);
self.in_memory_cursor = in_memory_cursor;
self.db_cursor_state = DbCursorState::new(cursor_wiped);
}
}

Expand Down Expand Up @@ -507,7 +574,7 @@ mod tests {
let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
let db_nodes_arc = Arc::new(db_nodes_map);
let visited_keys = Arc::new(Mutex::new(Vec::new()));
let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys.clone());

let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
Expand All @@ -520,6 +587,7 @@ mod tests {
BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)
))
);
assert!(visited_keys.lock().is_empty(), "exact overlay hit should not touch the DB cursor");

let result = cursor.seek_exact(Nibbles::from_nibbles([0x3])).unwrap();
assert_eq!(
Expand Down
Loading