diff --git a/Cargo.lock b/Cargo.lock index 04d4a09c792..297410b8989 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -399,6 +399,15 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "fastrand" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a407cfaa3385c4ae6b23e84623d48c2798d06e3e6a1878f7f59f17b3f86499" +dependencies = [ + "instant", +] + [[package]] name = "fastrlp" version = "0.1.3" @@ -747,6 +756,15 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + [[package]] name = "itoa" version = "1.0.3" @@ -1259,6 +1277,15 @@ version = "0.6.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" +[[package]] +name = "remove_dir_all" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" +dependencies = [ + "winapi", +] + [[package]] name = "reth" version = "0.1.0" @@ -1321,8 +1348,14 @@ name = "reth-stages" version = "0.1.0" dependencies = [ "async-trait", + "reth-db", "reth-primitives", + "tempfile", "thiserror", + "tokio", + "tokio-stream", + "tracing", + "tracing-futures", ] [[package]] @@ -1609,6 +1642,20 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tempfile" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" +dependencies = [ + "cfg-if", + "fastrand", + "libc", + "redox_syscall", + "remove_dir_all", + "winapi", +] + [[package]] name = "thiserror" version = 
"1.0.37" diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index e635d30fab9..0189c5803f0 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -9,5 +9,14 @@ description = "Staged syncing primitives used in reth." [dependencies] reth-primitives = { path = "../primitives" } +reth-db = { path = "../db" } async-trait = "0.1.57" -thiserror = "1.0.37" \ No newline at end of file +thiserror = "1.0.37" +tracing = "0.1.36" +tracing-futures = "0.2.5" +tokio = { version = "1.21.2", features = ["sync"] } + +[dev-dependencies] +tokio = { version = "*", features = ["rt", "sync", "macros"] } +tokio-stream = "0.1.10" +tempfile = "3.3.0" \ No newline at end of file diff --git a/crates/stages/src/lib.rs b/crates/stages/src/lib.rs index 7edf87e0829..99252837a9a 100644 --- a/crates/stages/src/lib.rs +++ b/crates/stages/src/lib.rs @@ -9,14 +9,16 @@ //! See [Stage] and [Pipeline]. use async_trait::async_trait; +use reth_db::mdbx; use reth_primitives::U64; +use std::fmt::Display; use thiserror::Error; mod pipeline; pub use pipeline::*; /// Stage execution input, see [Stage::execute]. -#[derive(Clone, Copy, Debug)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct ExecInput { /// The stage that was run before the current stage and the block number it reached. pub previous_stage: Option<(StageId, U64)>, @@ -25,7 +27,7 @@ pub struct ExecInput { } /// Stage unwind input, see [Stage::unwind]. -#[derive(Clone, Copy, Debug)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct UnwindInput { /// The current highest block of the stage. pub stage_progress: U64, @@ -36,7 +38,7 @@ pub struct UnwindInput { } /// The output of a stage execution. -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct ExecOutput { /// How far the stage got. pub stage_progress: U64, @@ -47,7 +49,7 @@ pub struct ExecOutput { } /// The output of a stage unwinding. 
-#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct UnwindOutput { /// The block at which the stage has unwound to. pub stage_progress: U64, @@ -59,19 +61,16 @@ pub enum StageError { /// The stage encountered a state validation error. /// /// TODO: This depends on the consensus engine and should include the validation failure reason - #[error("Stage encountered a validation error.")] - Validation, + #[error("Stage encountered a validation error in block {block}.")] + Validation { + /// The block that failed validation. + block: U64, + }, /// The stage encountered an internal error. #[error(transparent)] Internal(Box), } -/// The ID of a stage. -/// -/// Each stage ID must be unique. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct StageId(pub &'static str); - /// A stage is a segmented part of the syncing process of the node. /// /// Each stage takes care of a well-defined task, such as downloading headers or executing @@ -82,26 +81,77 @@ pub struct StageId(pub &'static str); /// /// Stages are executed as part of a pipeline where they are executed serially. #[async_trait] -pub trait Stage { +pub trait Stage<'db, E>: Send + Sync +where + E: mdbx::EnvironmentKind, +{ /// Get the ID of the stage. /// /// Stage IDs must be unique. fn id(&self) -> StageId; /// Execute the stage. - async fn execute( + async fn execute<'tx>( &mut self, - tx: &mut dyn DbTransaction, + tx: &mut mdbx::Transaction<'tx, mdbx::RW, E>, input: ExecInput, - ) -> Result; + ) -> Result + where + 'db: 'tx; /// Unwind the stage. - async fn unwind( + async fn unwind<'tx>( &mut self, - tx: &mut dyn DbTransaction, + tx: &mut mdbx::Transaction<'tx, mdbx::RW, E>, input: UnwindInput, - ) -> Result>; + ) -> Result> + where + 'db: 'tx; } -/// TODO: Stand-in for database-related abstractions. -pub trait DbTransaction {} +/// The ID of a stage. +/// +/// Each stage ID must be unique. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct StageId(pub &'static str); + +impl Display for StageId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl StageId { + /// Get the last committed progress of this stage. + pub fn get_progress<'db, K, E>( + &self, + tx: &mdbx::Transaction<'db, K, E>, + ) -> Result, mdbx::Error> + where + K: mdbx::TransactionKind, + E: mdbx::EnvironmentKind, + { + // TODO: Clean up when we get better database abstractions + let bytes: Option> = tx.get(&tx.open_db(Some("SyncStage"))?, self.0.as_ref())?; + + Ok(bytes.map(|b| U64::from_big_endian(b.as_ref()))) + } + + /// Save the progress of this stage. + pub fn save_progress<'db, E>( + &self, + tx: &mdbx::Transaction<'db, mdbx::RW, E>, + block: U64, + ) -> Result<(), mdbx::Error> + where + E: mdbx::EnvironmentKind, + { + // TODO: Clean up when we get better database abstractions + tx.put( + &tx.open_db(Some("SyncStage"))?, + self.0, + block.0[0].to_be_bytes(), + mdbx::WriteFlags::UPSERT, + ) + } +} diff --git a/crates/stages/src/pipeline.rs b/crates/stages/src/pipeline.rs index 7fb874f1497..39fef096c90 100644 --- a/crates/stages/src/pipeline.rs +++ b/crates/stages/src/pipeline.rs @@ -1,11 +1,16 @@ -use crate::Stage; +use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; +use reth_db::mdbx; use reth_primitives::U64; use std::fmt::{Debug, Formatter}; +use tokio::sync::mpsc::Sender; +use tracing::*; -#[allow(dead_code)] -struct QueuedStage { +struct QueuedStage<'db, E> +where + E: mdbx::EnvironmentKind, +{ /// The actual stage to execute. - stage: Box, + stage: Box>, /// The unwind priority of the stage. unwind_priority: usize, /// Whether or not this stage can only execute when we reach what we believe to be the tip of @@ -21,56 +26,78 @@ struct QueuedStage { /// tip. 
/// /// After the entire pipeline has been run, it will run again unless asked to stop (see -/// [Pipeline::set_exit_after_sync]). +/// [Pipeline::set_max_block]). /// /// # Unwinding /// /// In case of a validation error (as determined by the consensus engine) in one of the stages, the /// pipeline will unwind the stages according to their unwind priority. It is also possible to -/// request an unwind manually (see [Pipeline::start_with_unwind]). +/// request an unwind manually (see [Pipeline::unwind]). /// /// The unwind priority is set with [Pipeline::push_with_unwind_priority]. Stages with higher unwind /// priorities are unwound first. -#[derive(Default)] -pub struct Pipeline { - stages: Vec, - unwind_to: Option, +pub struct Pipeline<'db, E> +where + E: mdbx::EnvironmentKind, +{ + stages: Vec>, max_block: Option, - exit_after_sync: bool, + events_sender: Option>, } -impl Debug for Pipeline { +impl<'db, E> Default for Pipeline<'db, E> +where + E: mdbx::EnvironmentKind, +{ + fn default() -> Self { + Self { stages: Vec::new(), max_block: None, events_sender: None } + } +} + +impl<'db, E> Debug for Pipeline<'db, E> +where + E: mdbx::EnvironmentKind, +{ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Pipeline") - .field("unwind_to", &self.unwind_to) - .field("max_block", &self.max_block) - .field("exit_after_sync", &self.exit_after_sync) - .finish() + f.debug_struct("Pipeline").field("max_block", &self.max_block).finish() } } -impl Pipeline { +impl<'db, E> Pipeline<'db, E> +where + E: mdbx::EnvironmentKind, +{ + /// Create a new pipeline. + pub fn new() -> Self { + Default::default() + } + + /// Create a new pipeline with a channel for receiving events (see [PipelineEvent]). + pub fn new_with_channel(sender: Sender) -> Self { + Self::new().set_channel(sender) + } + /// Add a stage to the pipeline. /// /// # Unwinding /// /// The unwind priority is set to 0. 
- pub fn push(&mut self, stage: S, require_tip: bool) -> &mut Self + pub fn push(self, stage: S, require_tip: bool) -> Self where - S: Stage + 'static, + S: Stage<'db, E> + 'static, { self.push_with_unwind_priority(stage, require_tip, 0) } /// Add a stage to the pipeline, specifying the unwind priority. pub fn push_with_unwind_priority( - &mut self, + mut self, stage: S, require_tip: bool, unwind_priority: usize, - ) -> &mut Self + ) -> Self where - S: Stage + 'static, + S: Stage<'db, E> + 'static, { self.stages.push(QueuedStage { stage: Box::new(stage), require_tip, unwind_priority }); self @@ -79,25 +106,621 @@ impl Pipeline { /// Set the target block. /// /// Once this block is reached, syncing will stop. - pub fn set_max_block(&mut self, block: Option) -> &mut Self { + pub fn set_max_block(mut self, block: Option) -> Self { self.max_block = block; self } - /// Start the pipeline by unwinding to the specified block. - pub fn start_with_unwind(&mut self, unwind_to: Option) -> &mut Self { - self.unwind_to = unwind_to; + /// Set a channel the pipeline will transmit events over (see [PipelineEvent]). + pub fn set_channel(mut self, sender: Sender) -> Self { + self.events_sender = Some(sender); self } - /// Control whether the pipeline should exit after syncing. - pub fn set_exit_after_sync(&mut self, exit: bool) -> &mut Self { - self.exit_after_sync = exit; - self + /// Run the pipeline. + pub async fn run( + &mut self, + db: &'db mdbx::Environment, + ) -> Result<(), Box> { + let mut previous_stage = None; + let mut minimum_progress: Option = None; + let mut maximum_progress: Option = None; + let mut reached_tip_flag = true; + + 'run: loop { + let mut tx = db.begin_rw_txn()?; + for (_, QueuedStage { stage, require_tip, .. 
}) in self.stages.iter_mut().enumerate() { + let stage_id = stage.id(); + let block_reached = loop { + let prev_progress = stage_id.get_progress(&tx)?; + + if let Some(rx) = &self.events_sender { + rx.send(PipelineEvent::Running { stage_id, stage_progress: prev_progress }) + .await? + } + + let reached_virtual_tip = maximum_progress + .zip(self.max_block) + .map_or(false, |(progress, target)| progress >= target); + + // Execute stage + let output = async { + if !reached_tip_flag && *require_tip && !reached_virtual_tip { + info!("Tip not reached, skipping."); + + // Stage requires us to reach the tip of the chain first, but we have + // not. + Ok(ExecOutput { + stage_progress: prev_progress.unwrap_or_default(), + done: true, + reached_tip: false, + }) + } else if prev_progress + .zip(self.max_block) + .map_or(false, |(prev_progress, target)| prev_progress >= target) + { + info!("Stage reached maximum block, skipping."); + // We reached the maximum block, so we skip the stage + Ok(ExecOutput { + stage_progress: prev_progress.unwrap_or_default(), + done: true, + reached_tip: true, + }) + } else { + stage + .execute( + &mut tx, + ExecInput { previous_stage, stage_progress: prev_progress }, + ) + .await + } + } + .instrument(info_span!("Running", stage = %stage_id)) + .await; + + match output { + Ok(out @ ExecOutput { stage_progress, done, reached_tip }) => { + debug!(stage = %stage_id, %stage_progress, %done, "Stage made progress"); + stage_id.save_progress(&tx, stage_progress)?; + + if let Some(rx) = &self.events_sender { + rx.send(PipelineEvent::Ran { stage_id, result: Some(out.clone()) }) + .await? 
+ } + + // TODO: Make the commit interval configurable + tx.commit()?; + tx = db.begin_rw_txn()?; + + // TODO: Clean up + if let Some(min) = &mut minimum_progress { + *min = std::cmp::min(*min, stage_progress); + } else { + minimum_progress = Some(stage_progress); + } + if let Some(max) = &mut maximum_progress { + *max = std::cmp::max(*max, stage_progress); + } else { + maximum_progress = Some(stage_progress); + } + + if done { + reached_tip_flag = reached_tip; + break stage_progress + } + } + Err(StageError::Validation { block }) => { + debug!(stage = %stage_id, bad_block = %block, "Stage encountered a validation error."); + + if let Some(rx) = &self.events_sender { + rx.send(PipelineEvent::Ran { stage_id, result: None }).await? + } + + // We unwind because of a validation error. If the unwind itself fails, + // we bail entirely, otherwise we restart the execution loop from the + // beginning. + // + // Note on the drop: The transaction needs to be dropped in order for + // unwind to create a new one. Dropping the + // transaction will abort it; there is no + // other way currently to abort the transaction. It will be re-created + // if the loop restarts. + drop(tx); + match self + .unwind(db, prev_progress.unwrap_or_default(), Some(block)) + .await + { + Ok(()) => continue 'run, + Err(e) => return Err(e), + } + } + Err(StageError::Internal(e)) => { + if let Some(rx) = &self.events_sender { + rx.send(PipelineEvent::Ran { stage_id, result: None }).await? + } + + return Err(e) + } + } + }; + + // Set previous stage and continue on to next stage. + previous_stage = Some((stage_id, block_reached)); + } + tx.commit()?; + + // Check if we've reached our desired target block + if minimum_progress + .zip(self.max_block) + .map_or(false, |(progress, target)| progress >= target) + { + return Ok(()) + } + } } - /// Run the pipeline. - pub async fn run(&mut self) -> Result<(), Box> { - todo!() + /// Unwind the stages to the target block. 
+ /// + /// If the unwind is due to a bad block the number of that block should be specified. + pub async fn unwind( + &mut self, + db: &'db mdbx::Environment, + to: U64, + bad_block: Option, + ) -> Result<(), Box> { + // Sort stages by unwind priority + let mut unwind_pipeline = { + let mut stages: Vec<_> = self.stages.iter_mut().enumerate().collect(); + stages.sort_by_key(|(id, stage)| { + if stage.unwind_priority > 0 { + (id - stage.unwind_priority, 0) + } else { + (*id, 1) + } + }); + stages.reverse(); + stages + }; + + // Unwind stages in reverse order of priority (i.e. higher priority = first) + let mut tx = db.begin_rw_txn()?; + for (_, QueuedStage { stage, .. }) in unwind_pipeline.iter_mut() { + let stage_id = stage.id(); + let mut stage_progress = stage_id.get_progress(&tx)?.unwrap_or_default(); + + let unwind: Result<(), Box> = async { + if stage_progress < to { + debug!(from = %stage_progress, %to, "Unwind point too far for stage"); + if let Some(rx) = &self.events_sender { + rx.send(PipelineEvent::Unwound { + stage_id, + result: Some(UnwindOutput { stage_progress }), + }) + .await? + } + + return Ok(()) + } + + debug!(from = %stage_progress, %to, ?bad_block, "Starting unwind"); + while stage_progress > to { + let input = UnwindInput { stage_progress, unwind_to: to, bad_block }; + if let Some(rx) = &self.events_sender { + rx.send(PipelineEvent::Unwinding { stage_id, input }).await? + } + + let output = stage.unwind(&mut tx, input).await; + match output { + Ok(unwind_output) => { + stage_progress = unwind_output.stage_progress; + stage_id.save_progress(&tx, stage_progress)?; + + if let Some(rx) = &self.events_sender { + rx.send(PipelineEvent::Unwound { + stage_id, + result: Some(unwind_output), + }) + .await? + } + } + Err(err) => { + if let Some(rx) = &self.events_sender { + rx.send(PipelineEvent::Unwound { stage_id, result: None }).await? 
+ } + + return Err(err) + } + } + } + + Ok(()) + } + .instrument(info_span!("Unwinding", stage = %stage_id)) + .await; + unwind? + } + + tx.commit()?; + Ok(()) + } +} + +/// An event emitted by a [Pipeline]. +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum PipelineEvent { + /// Emitted when a stage is about to be run. + Running { + /// The stage that is about to be run. + stage_id: StageId, + /// The previous checkpoint of the stage. + stage_progress: Option, + }, + /// Emitted when a stage has run a single time. + /// + /// It is possible for multiple of these events to be emitted over the duration of a pipeline's + /// execution: + /// - If the pipeline loops, the stage will be run again at some point + /// - If the stage exits early but has acknowledged that it is not entirely done + Ran { + /// The stage that was run. + stage_id: StageId, + /// The result of executing the stage. If it is None then an error was encountered. + result: Option, + }, + /// Emitted when a stage is about to be unwound. + Unwinding { + /// The stage that is about to be unwound. + stage_id: StageId, + /// The unwind parameters. + input: UnwindInput, + }, + /// Emitted when a stage has been unwound. + /// + /// It is possible for multiple of these events to be emitted over the duration of a pipeline's + /// execution, since other stages may ask the pipeline to unwind. + Unwound { + /// The stage that was unwound. + stage_id: StageId, + /// The result of unwinding the stage. If it is None then an error was encountered. + result: Option, + }, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{StageId, UnwindOutput}; + use reth_db::mdbx; + use tempfile::tempdir; + use tokio::sync::mpsc::channel; + use tokio_stream::{wrappers::ReceiverStream, StreamExt}; + use utils::TestStage; + + /// Runs a simple pipeline. 
+ #[tokio::test] + async fn run_pipeline() { + let (tx, rx) = channel(2); + let db = utils::test_db().expect("Could not open test database"); + + // Run pipeline + tokio::spawn(async move { + Pipeline::::new_with_channel(tx) + .push( + TestStage::new(StageId("A")).add_exec(Ok(ExecOutput { + stage_progress: 20.into(), + done: true, + reached_tip: true, + })), + false, + ) + .push( + TestStage::new(StageId("B")).add_exec(Ok(ExecOutput { + stage_progress: 10.into(), + done: true, + reached_tip: true, + })), + false, + ) + .set_max_block(Some(10.into())) + .run(&db) + .await + }); + + // Check that the stages were run in order + assert_eq!( + ReceiverStream::new(rx).collect::>().await, + vec![ + PipelineEvent::Running { stage_id: StageId("A"), stage_progress: None }, + PipelineEvent::Ran { + stage_id: StageId("A"), + result: Some(ExecOutput { + stage_progress: 20.into(), + done: true, + reached_tip: true, + }), + }, + PipelineEvent::Running { stage_id: StageId("B"), stage_progress: None }, + PipelineEvent::Ran { + stage_id: StageId("B"), + result: Some(ExecOutput { + stage_progress: 10.into(), + done: true, + reached_tip: true, + }), + }, + ] + ); + } + + /// Unwinds a simple pipeline. 
+ #[tokio::test] + async fn unwind_pipeline() { + let (tx, rx) = channel(2); + let db = utils::test_db().expect("Could not open test database"); + + // Run pipeline + tokio::spawn(async move { + let mut pipeline = Pipeline::::new() + .push( + TestStage::new(StageId("A")) + .add_exec(Ok(ExecOutput { + stage_progress: 100.into(), + done: true, + reached_tip: true, + })) + .add_unwind(Ok(UnwindOutput { stage_progress: 1.into() })), + false, + ) + .push( + TestStage::new(StageId("B")) + .add_exec(Ok(ExecOutput { + stage_progress: 10.into(), + done: true, + reached_tip: true, + })) + .add_unwind(Ok(UnwindOutput { stage_progress: 1.into() })), + false, + ) + .set_max_block(Some(10.into())); + + // Sync first + pipeline.run(&db).await.expect("Could not run pipeline"); + + // Unwind + pipeline + .set_channel(tx) + .unwind(&db, 1.into(), None) + .await + .expect("Could not unwind pipeline"); + }); + + // Check that the stages were unwound in reverse order + assert_eq!( + ReceiverStream::new(rx).collect::>().await, + vec![ + PipelineEvent::Unwinding { + stage_id: StageId("B"), + input: UnwindInput { + stage_progress: 10.into(), + unwind_to: 1.into(), + bad_block: None + } + }, + PipelineEvent::Unwound { + stage_id: StageId("B"), + result: Some(UnwindOutput { stage_progress: 1.into() }), + }, + PipelineEvent::Unwinding { + stage_id: StageId("A"), + input: UnwindInput { + stage_progress: 100.into(), + unwind_to: 1.into(), + bad_block: None + } + }, + PipelineEvent::Unwound { + stage_id: StageId("A"), + result: Some(UnwindOutput { stage_progress: 1.into() }), + }, + ] + ); + } + + /// Runs a pipeline that unwinds during sync. 
+ /// + /// The flow is: + /// + /// - Stage A syncs to block 10 + /// - Stage B triggers an unwind, marking block 5 as bad + /// - Stage B unwinds to its previous progress, block 0, but since it is still at block 0, it is + /// skipped entirely (there is nothing to unwind) + /// - Stage A unwinds to its previous progress, block 0 + /// - Stage A syncs back up to block 10 + /// - Stage B syncs to block 10 + /// - The pipeline finishes + #[tokio::test] + async fn run_pipeline_with_unwind() { + let (tx, rx) = channel(2); + let db = utils::test_db().expect("Could not open test database"); + + // Run pipeline + tokio::spawn(async move { + Pipeline::::new() + .push( + TestStage::new(StageId("A")) + .add_exec(Ok(ExecOutput { + stage_progress: 10.into(), + done: true, + reached_tip: true, + })) + .add_unwind(Ok(UnwindOutput { stage_progress: 0.into() })) + .add_exec(Ok(ExecOutput { + stage_progress: 10.into(), + done: true, + reached_tip: true, + })), + false, + ) + .push( + TestStage::new(StageId("B")) + .add_exec(Err(StageError::Validation { block: 5.into() })) + .add_unwind(Ok(UnwindOutput { stage_progress: 0.into() })) + .add_exec(Ok(ExecOutput { + stage_progress: 10.into(), + done: true, + reached_tip: true, + })), + false, + ) + .set_max_block(Some(10.into())) + .set_channel(tx) + .run(&db) + .await + .expect("Could not run pipeline"); + }); + + // Check that the validation error triggered an unwind and that both stages then re-synced + assert_eq!( + ReceiverStream::new(rx).collect::>().await, + vec![ + PipelineEvent::Running { stage_id: StageId("A"), stage_progress: None }, + PipelineEvent::Ran { + stage_id: StageId("A"), + result: Some(ExecOutput { + stage_progress: 10.into(), + done: true, + reached_tip: true, + }), + }, + PipelineEvent::Running { stage_id: StageId("B"), stage_progress: None }, + PipelineEvent::Ran { stage_id: StageId("B"), result: None }, + PipelineEvent::Unwinding { + stage_id: StageId("A"), + input: UnwindInput { + stage_progress: 10.into(), + unwind_to: 0.into(), + bad_block:
Some(5.into()) + } + }, + PipelineEvent::Unwound { + stage_id: StageId("A"), + result: Some(UnwindOutput { stage_progress: 0.into() }), + }, + PipelineEvent::Running { stage_id: StageId("A"), stage_progress: Some(0.into()) }, + PipelineEvent::Ran { + stage_id: StageId("A"), + result: Some(ExecOutput { + stage_progress: 10.into(), + done: true, + reached_tip: true, + }), + }, + PipelineEvent::Running { stage_id: StageId("B"), stage_progress: None }, + PipelineEvent::Ran { + stage_id: StageId("B"), + result: Some(ExecOutput { + stage_progress: 10.into(), + done: true, + reached_tip: true, + }), + }, + ] + ); + } + + mod utils { + use super::*; + use async_trait::async_trait; + use std::{collections::VecDeque, error::Error}; + + // TODO: This is... not great. + pub(crate) fn test_db() -> Result, mdbx::Error> { + const DB_TABLES: usize = 10; + + // Build environment + let mut builder = mdbx::Environment::::new(); + builder.set_max_dbs(DB_TABLES); + builder.set_geometry(mdbx::Geometry { + size: Some(0..usize::MAX), + growth_step: None, + shrink_threshold: None, + page_size: None, + }); + builder.set_rp_augment_limit(16 * 256 * 1024); + + // Open + let tempdir = tempdir().unwrap(); + let path = tempdir.path(); + std::fs::DirBuilder::new().recursive(true).create(path).unwrap(); + let db = builder.open(path)?; + + // Create tables + let tx = db.begin_rw_txn()?; + tx.create_db(Some("SyncStage"), mdbx::DatabaseFlags::default())?; + tx.commit()?; + + Ok(db) + } + + pub(crate) struct TestStage { + id: StageId, + exec_outputs: VecDeque>, + unwind_outputs: VecDeque>>, + } + + impl TestStage { + pub(crate) fn new(id: StageId) -> Self { + Self { id, exec_outputs: VecDeque::new(), unwind_outputs: VecDeque::new() } + } + + pub(crate) fn add_exec(mut self, output: Result) -> Self { + self.exec_outputs.push_back(output); + self + } + + pub(crate) fn add_unwind( + mut self, + output: Result>, + ) -> Self { + self.unwind_outputs.push_back(output); + self + } + } + + #[async_trait] + 
impl<'db, E> Stage<'db, E> for TestStage + where + E: mdbx::EnvironmentKind, + { + fn id(&self) -> StageId { + self.id + } + + async fn execute<'tx>( + &mut self, + _: &mut mdbx::Transaction<'tx, mdbx::RW, E>, + _: ExecInput, + ) -> Result + where + 'db: 'tx, + { + self.exec_outputs + .pop_front() + .unwrap_or_else(|| panic!("Test stage {} executed too many times.", self.id)) + } + + async fn unwind<'tx>( + &mut self, + _: &mut mdbx::Transaction<'tx, mdbx::RW, E>, + _: UnwindInput, + ) -> Result> + where + 'db: 'tx, + { + self.unwind_outputs + .pop_front() + .unwrap_or_else(|| panic!("Test stage {} unwound too many times.", self.id)) + } + } } }