diff --git a/crates/supervisor/core/src/chain_processor/handlers/invalidation.rs b/crates/supervisor/core/src/chain_processor/handlers/invalidation.rs index 616efe6d71..036d497a2c 100644 --- a/crates/supervisor/core/src/chain_processor/handlers/invalidation.rs +++ b/crates/supervisor/core/src/chain_processor/handlers/invalidation.rs @@ -315,6 +315,7 @@ mod tests { impl StorageRewinder for Db { fn rewind_log_storage(&self, to: &BlockNumHash) -> Result<(), StorageError>; fn rewind(&self, to: &BlockNumHash) -> Result<(), StorageError>; + fn rewind_to_source(&self, to: &BlockNumHash) -> Result, StorageError>; } ); diff --git a/crates/supervisor/core/src/chain_processor/handlers/safe_block.rs b/crates/supervisor/core/src/chain_processor/handlers/safe_block.rs index d234810ce6..abcba73d4a 100644 --- a/crates/supervisor/core/src/chain_processor/handlers/safe_block.rs +++ b/crates/supervisor/core/src/chain_processor/handlers/safe_block.rs @@ -39,7 +39,7 @@ where trace!( target: "supervisor::chain_processor", chain_id = self.chain_id, - block_number = derived_ref_pair.derived.number, + %derived_ref_pair, "Processing local safe derived block pair" ); @@ -382,6 +382,7 @@ mod tests { impl StorageRewinder for Db { fn rewind_log_storage(&self, to: &BlockNumHash) -> Result<(), StorageError>; fn rewind(&self, to: &BlockNumHash) -> Result<(), StorageError>; + fn rewind_to_source(&self, to: &BlockNumHash) -> Result, StorageError>; } ); diff --git a/crates/supervisor/core/src/reorg/handler.rs b/crates/supervisor/core/src/reorg/handler.rs index eb4c2924d3..6449a7f277 100644 --- a/crates/supervisor/core/src/reorg/handler.rs +++ b/crates/supervisor/core/src/reorg/handler.rs @@ -8,7 +8,7 @@ use kona_protocol::BlockInfo; use kona_supervisor_metrics::observe_metrics_for_result_async; use kona_supervisor_storage::{DbReader, StorageRewinder}; use std::{collections::HashMap, sync::Arc}; -use tracing::{info, warn}; +use tracing::{error, info}; /// Handles L1 reorg operations for multiple chains #[derive(Debug, Constructor)] @@ -94,7 +94,7 @@ where let results = future::join_all(handles).await; for result in results { if let Err(err) = result { - warn!(target: "supervisor::reorg_handler", %err, "Reorg task failed"); + error!(target: "supervisor::reorg_handler", %err, "Reorg task failed"); } } diff --git a/crates/supervisor/core/src/reorg/task.rs b/crates/supervisor/core/src/reorg/task.rs index b9d75b6aca..cbe8f6cc5f 100644 --- a/crates/supervisor/core/src/reorg/task.rs +++ b/crates/supervisor/core/src/reorg/task.rs @@ -20,6 +20,12 @@ pub(crate) struct ReorgTask { managed_node: Arc, } +#[derive(Debug)] +struct RewoundState { + source: BlockInfo, + derived: Option, +} + impl ReorgTask where C: ManagedNodeController + Send + Sync + 'static, @@ -28,6 +34,12 @@ where /// Processes reorg for a single chain. If the chain is consistent with the L1 chain, /// does nothing. pub(crate) async fn process_chain_reorg(&self) -> Result<(), ReorgHandlerError> { + trace!( + target: "supervisor::reorg_handler", + chain_id = %self.chain_id, + "Processing reorg for chain..." + ); + let latest_state = self.db.latest_derivation_state()?; // Find last valid source block for this chain @@ -59,11 +71,12 @@ where // record metrics if let Some(rewound_state) = rewound_state { - Metrics::record_block_depth( - self.chain_id, - latest_state.source.number - rewound_state.source.number, - latest_state.derived.number - rewound_state.derived.number, - ); + let l1_depth = latest_state.source.number - rewound_state.source.number; + let mut l2_depth = 0; + if let Some(derived) = rewound_state.derived { + l2_depth = latest_state.derived.number - derived.number; + } + Metrics::record_block_depth(self.chain_id, l1_depth, l2_depth); } Ok(()) } @@ -71,29 +84,35 @@ where async fn rewind_to_target_source( &self, rewind_target_source: BlockInfo, - ) -> Result { - // Get the derived block at the target source block - let rewind_target_derived = - self.db.latest_derived_block_at_source(rewind_target_source.id())?; + ) -> Result { + trace!( + target: "supervisor::reorg_handler", + chain_id = %self.chain_id, + rewind_target_source = rewind_target_source.number, + "Rewinding to target source block..." + ); - // rewind_to() method is inclusive, so we need to get the next block. - let rewind_to = self.db.get_block(rewind_target_derived.number + 1)?; // Call the rewinder to handle the DB rewinding - self.db.rewind(&rewind_to.id()).inspect_err(|err| { - warn!( - target: "supervisor::reorg_handler::db", - chain_id = %self.chain_id, - %err, - "Failed to rewind DB to derived block" - ); - })?; + let derived_block_rewounded = + self.db.rewind_to_source(&rewind_target_source.id()).inspect_err(|err| { + warn!( + target: "supervisor::reorg_handler::db", + chain_id = %self.chain_id, + %err, + "Failed to rewind DB to derived block" + ); + })?; - Ok(DerivedRefPair { source: rewind_target_source, derived: rewind_target_derived }) + Ok(RewoundState { source: rewind_target_source, derived: derived_block_rewounded }) } - async fn rewind_to_activation_block( - &self, - ) -> Result, ReorgHandlerError> { + async fn rewind_to_activation_block(&self) -> Result, ReorgHandlerError> { + trace!( + target: "supervisor::reorg_handler", + chain_id = %self.chain_id, + "Rewinding to activation block..." + ); + // If the rewind target is pre-interop, we need to rewind to the activation block match self.db.get_activation_block() { Ok(activation_block) => { @@ -106,9 +125,9 @@ where "Failed to rewind DB to activation block" ); })?; - Ok(Some(DerivedRefPair { + Ok(Some(RewoundState { source: activation_source_block, - derived: activation_block, + derived: Some(activation_block), })) } Err(StorageError::DatabaseNotInitialised) => { @@ -149,8 +168,9 @@ where return Ok(None); } - let mut common_ancestor = self.find_common_ancestor().await?; - let mut current_source = latest_state.source; + let common_ancestor = self.find_common_ancestor().await?; + let mut prev_source = latest_state.source; + let mut current_source = self.db.get_source_block(prev_source.number - 1)?; while current_source.number > common_ancestor.number { if current_source.number % 5 == 0 { @@ -170,15 +190,16 @@ where block_number = current_source.number, "Found canonical block as rewind target" ); - common_ancestor = current_source; break; } // Otherwise, walk back to the previous source block + prev_source = current_source; current_source = self.db.get_source_block(current_source.number - 1)?; } - Ok(Some(common_ancestor)) + // return the previous source block as the rewind target since rewinding is inclusive + Ok(Some(prev_source)) } async fn find_common_ancestor(&self) -> Result { @@ -301,6 +322,7 @@ mod tests { impl StorageRewinder for Db { fn rewind(&self, to: &BlockNumHash) -> Result<(), StorageError>; fn rewind_log_storage(&self, to: &BlockNumHash) -> Result<(), StorageError>; + fn rewind_to_source(&self, to: &BlockNumHash) -> Result, StorageError>; } ); @@ -381,14 +403,15 @@ mod tests { derived: BlockInfo::new(B256::from([3u8; 32]), 50, B256::from([4u8; 32]), 12346), }; + let canonical_source = + BlockInfo::new(B256::from([5u8; 32]), 95, B256::from([6u8; 32]), 12344); + let rewind_target_source = - BlockInfo::new(B256::from([10u8; 32]), 95, B256::from([11u8; 32]), 12340); + BlockInfo::new(B256::from([10u8; 32]), 96, B256::from([11u8; 32]), 12340); let rewind_target_derived = BlockInfo::new(B256::from([12u8; 32]), 45, B256::from([13u8; 32]), 12341); - let next_block = BlockInfo::new(B256::from([14u8; 32]), 46, B256::from([12u8; 32]), 12342); - let finalized_block = BlockInfo::new(B256::from([20u8; 32]), 40, B256::from([21u8; 32]), 12330); @@ -398,15 +421,15 @@ mod tests { // Mock finding common ancestor mock_db.expect_get_safety_head_ref().times(1).returning(move |_| Ok(finalized_block)); - mock_db.expect_derived_to_source().times(1).returning(move |_| Ok(rewind_target_source)); + mock_db.expect_derived_to_source().times(1).returning(move |_| Ok(canonical_source)); mock_db.expect_get_source_block().times(5).returning( move |block_number| match block_number { 99 => Ok(BlockInfo::new(B256::from([16u8; 32]), 99, B256::from([17u8; 32]), 12344)), 98 => Ok(BlockInfo::new(B256::from([17u8; 32]), 98, B256::from([18u8; 32]), 12343)), 97 => Ok(BlockInfo::new(B256::from([18u8; 32]), 97, B256::from([19u8; 32]), 12342)), - 96 => Ok(BlockInfo::new(B256::from([19u8; 32]), 96, B256::from([20u8; 32]), 12341)), - 95 => Ok(rewind_target_source), + 96 => Ok(rewind_target_source), + 95 => Ok(canonical_source), _ => Err(StorageError::ConflictError), }, ); @@ -439,11 +462,11 @@ mod tests { // Second call for checking if rewind target is canonical let canonical_block: Block = Block { header: Header { - hash: rewind_target_source.hash, + hash: canonical_source.hash, inner: alloy_consensus::Header { - number: rewind_target_source.number, - parent_hash: rewind_target_source.parent_hash, - timestamp: rewind_target_source.timestamp, + number: canonical_source.number, + parent_hash: canonical_source.parent_hash, + timestamp: canonical_source.timestamp, ..Default::default() }, ..Default::default() @@ -454,13 +477,9 @@ mod tests { // Mock rewind operations mock_db - .expect_latest_derived_block_at_source() + .expect_rewind_to_source() .times(1) - .returning(move |_| Ok(rewind_target_derived)); - - mock_db.expect_get_block().times(1).returning(move |_| Ok(next_block)); - - mock_db.expect_rewind().times(1).returning(|_| Ok(())); + .returning(move |_| Ok(Some(rewind_target_derived))); // Managed node should be reset managed_node.expect_reset().times(1).returning(|| Ok(())); @@ -760,7 +779,7 @@ mod tests { // Should succeed since the latest source block is still canonical assert!(rewind_target.is_ok()); - assert_eq!(rewind_target.unwrap(), Some(finalized_state.source)); + assert_eq!(rewind_target.unwrap(), Some(source_39_info)); } #[tokio::test] @@ -903,7 +922,7 @@ mod tests { // Should succeed since the latest source block is still canonical assert!(rewind_target.is_ok()); - assert_eq!(rewind_target.unwrap(), Some(activation_state.source)); + assert_eq!(rewind_target.unwrap(), Some(source_39_info)); } #[tokio::test] @@ -1104,7 +1123,7 @@ mod tests { assert!(result.is_ok()); let pair = result.unwrap().unwrap(); assert_eq!(pair.source, activation_source); - assert_eq!(pair.derived, activation_block); + assert_eq!(pair.derived.unwrap(), activation_block); } #[tokio::test] @@ -1239,24 +1258,12 @@ mod tests { let rewind_target_derived = BlockInfo::new(B256::from([3u8; 32]), 50, B256::from([4u8; 32]), 12346); - let next_block = BlockInfo::new(B256::from([5u8; 32]), 51, B256::from([3u8; 32]), 12347); - - // Expect latest_derived_block_at_source to be called + // Expect rewind to be called mock_db - .expect_latest_derived_block_at_source() + .expect_rewind_to_source() .times(1) .with(predicate::eq(rewind_target_source.id())) - .returning(move |_| Ok(rewind_target_derived)); - - // Expect get_block to be called for the next block - mock_db - .expect_get_block() - .times(1) - .with(predicate::eq(51)) - .returning(move |_| Ok(next_block)); - - // Expect rewind to be called - mock_db.expect_rewind().times(1).with(predicate::eq(next_block.id())).returning(|_| Ok(())); + .returning(move |_| Ok(Some(rewind_target_derived))); let reorg_task = ReorgTask::new( 1, @@ -1270,73 +1277,7 @@ mod tests { assert!(result.is_ok()); let pair = result.unwrap(); assert_eq!(pair.source, rewind_target_source); - assert_eq!(pair.derived, rewind_target_derived); - } - - #[tokio::test] - async fn test_rewind_to_target_source_latest_derived_block_fails() { - let mut mock_db = MockDb::new(); - let managed_node = Arc::new(MockManagedNode::new()); - - let rewind_target_source = - BlockInfo::new(B256::from([1u8; 32]), 100, B256::from([2u8; 32]), 12345); - - // Expect latest_derived_block_at_source to fail - mock_db - .expect_latest_derived_block_at_source() - .times(1) - .returning(|_| Err(StorageError::LockPoisoned)); - - let reorg_task = ReorgTask::new( - 1, - Arc::new(mock_db), - RpcClient::new(MockTransport::new(Asserter::new()), false), - managed_node, - ); - - let result = reorg_task.rewind_to_target_source(rewind_target_source).await; - - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - ReorgHandlerError::StorageError(StorageError::LockPoisoned) - )); - } - - #[tokio::test] - async fn test_rewind_to_target_source_get_block_fails() { - let mut mock_db = MockDb::new(); - let managed_node = Arc::new(MockManagedNode::new()); - - let rewind_target_source = - BlockInfo::new(B256::from([1u8; 32]), 100, B256::from([2u8; 32]), 12345); - - let rewind_target_derived = - BlockInfo::new(B256::from([3u8; 32]), 50, B256::from([4u8; 32]), 12346); - - // Expect latest_derived_block_at_source to succeed - mock_db - .expect_latest_derived_block_at_source() - .times(1) - .returning(move |_| Ok(rewind_target_derived)); - - // Expect get_block to fail - mock_db.expect_get_block().times(1).returning(|_| Err(StorageError::LockPoisoned)); - - let reorg_task = ReorgTask::new( - 1, - Arc::new(mock_db), - RpcClient::new(MockTransport::new(Asserter::new()), false), - managed_node, - ); - - let result = reorg_task.rewind_to_target_source(rewind_target_source).await; - - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - ReorgHandlerError::StorageError(StorageError::LockPoisoned) - )); + assert_eq!(pair.derived.unwrap(), rewind_target_derived); } #[tokio::test] @@ -1347,22 +1288,8 @@ mod tests { let rewind_target_source = BlockInfo::new(B256::from([1u8; 32]), 100, B256::from([2u8; 32]), 12345); - let rewind_target_derived = - BlockInfo::new(B256::from([3u8; 32]), 50, B256::from([4u8; 32]), 12346); - - let next_block = BlockInfo::new(B256::from([5u8; 32]), 51, B256::from([3u8; 32]), 12347); - - // Expect latest_derived_block_at_source to succeed - mock_db - .expect_latest_derived_block_at_source() - .times(1) - .returning(move |_| Ok(rewind_target_derived)); - - // Expect get_block to succeed - mock_db.expect_get_block().times(1).returning(move |_| Ok(next_block)); - // Expect rewind to fail - mock_db.expect_rewind().times(1).returning(|_| Err(StorageError::LockPoisoned)); + mock_db.expect_rewind_to_source().times(1).returning(|_| Err(StorageError::LockPoisoned)); let reorg_task = ReorgTask::new( 1, diff --git a/crates/supervisor/service/src/actors/processor.rs b/crates/supervisor/service/src/actors/processor.rs index d477103a2e..7b64dd0e2e 100644 --- a/crates/supervisor/service/src/actors/processor.rs +++ b/crates/supervisor/service/src/actors/processor.rs @@ -205,6 +205,7 @@ mod tests { impl StorageRewinder for Db { fn rewind_log_storage(&self, to: &BlockNumHash) -> Result<(), StorageError>; fn rewind(&self, to: &BlockNumHash) -> Result<(), StorageError>; + fn rewind_to_source(&self, to: &BlockNumHash) -> Result, StorageError>; } ); diff --git a/crates/supervisor/storage/src/chaindb.rs b/crates/supervisor/storage/src/chaindb.rs index 5236d91542..71ceedd8e3 100644 --- a/crates/supervisor/storage/src/chaindb.rs +++ b/crates/supervisor/storage/src/chaindb.rs @@ -481,6 +481,43 @@ impl StorageRewinder for ChainDb { })? }) } + + fn rewind_to_source(&self, to: &BlockNumHash) -> Result, StorageError> { + self.observe_call(Metrics::STORAGE_METHOD_REWIND_TO_SOURCE, || { + self.env.update(|tx| { + let lp = LogProvider::new(tx, self.chain_id); + let dp = DerivationProvider::new(tx, self.chain_id); + let hp = SafetyHeadRefProvider::new(tx, self.chain_id); + + let derived_target_block = dp.rewind_to_source(to)?; + if let Some(rewind_target) = derived_target_block { + lp.rewind_to(&rewind_target.id())?; + } + + // get the current latest block to update the safety head refs + match lp.get_latest_block() { + Ok(latest_block) => { + hp.reset_safety_head_ref_if_ahead(SafetyLevel::LocalUnsafe, &latest_block)?; + hp.reset_safety_head_ref_if_ahead(SafetyLevel::CrossUnsafe, &latest_block)?; + hp.reset_safety_head_ref_if_ahead(SafetyLevel::LocalSafe, &latest_block)?; + hp.reset_safety_head_ref_if_ahead(SafetyLevel::CrossSafe, &latest_block)?; + hp.reset_safety_head_ref_if_ahead(SafetyLevel::Finalized, &latest_block)?; + } + Err(StorageError::DatabaseNotInitialised) => { + // If the database returns DatabaseNotInitialised, it means we have rewound + // past the activation block + hp.remove_safety_head_ref(SafetyLevel::LocalUnsafe)?; + hp.remove_safety_head_ref(SafetyLevel::CrossUnsafe)?; + hp.remove_safety_head_ref(SafetyLevel::LocalSafe)?; + hp.remove_safety_head_ref(SafetyLevel::CrossSafe)?; + hp.remove_safety_head_ref(SafetyLevel::Finalized)?; + } + Err(err) => return Err(err), + } + Ok(derived_target_block) + })? + }) + } } impl MetricsReporter for ChainDb { @@ -1409,4 +1446,178 @@ mod tests { let latest_log_block = db.get_latest_block(); assert!(matches!(latest_log_block, Err(StorageError::DatabaseNotInitialised))); } + + #[test] + fn test_rewind_to_source_updates_logs_and_heads() { + let tmp_dir = TempDir::new().expect("create temp dir"); + let db_path = tmp_dir.path().join("chaindb_rewind_to_source"); + let db = ChainDb::new(1, &db_path).expect("create db"); + + // Anchor (activation) + let anchor = DerivedRefPair { + source: BlockInfo { + hash: B256::from([0u8; 32]), + number: 100, + parent_hash: B256::from([1u8; 32]), + timestamp: 0, + }, + derived: BlockInfo { + hash: B256::from([2u8; 32]), + number: 0, + parent_hash: B256::from([3u8; 32]), + timestamp: 0, + }, + }; + + // Initialise DB with anchor + db.initialise_log_storage(anchor.derived).expect("initialise log storage"); + db.initialise_derivation_storage(anchor).expect("initialise derivation storage"); + + // Build two source entries and several derived blocks + let source1 = BlockInfo { + hash: B256::from([3u8; 32]), + number: 101, + parent_hash: anchor.source.hash, + timestamp: 0, + }; + let source2 = BlockInfo { + hash: B256::from([4u8; 32]), + number: 102, + parent_hash: source1.hash, + timestamp: 0, + }; + + // Derived blocks chained off the anchor/previous derived blocks + let derived1 = BlockInfo { + hash: B256::from([10u8; 32]), + number: 1, + parent_hash: anchor.derived.hash, + timestamp: 0, + }; + let derived2 = BlockInfo { + hash: B256::from([11u8; 32]), + number: 2, + parent_hash: derived1.hash, + timestamp: 0, + }; + let derived3 = BlockInfo { + hash: B256::from([12u8; 32]), + number: 3, + parent_hash: derived2.hash, + timestamp: 0, + }; + let derived4 = BlockInfo { + hash: B256::from([13u8; 32]), + number: 4, + parent_hash: derived3.hash, + timestamp: 0, + }; + let derived5 = BlockInfo { + hash: B256::from([14u8; 32]), + number: 5, + parent_hash: derived4.hash, + timestamp: 0, + }; + + // Insert sources and derived blocks into storage (logs + derivation) + assert!(db.save_source_block(source1).is_ok()); + db.store_block_logs(&derived1, vec![]).expect("store logs derived1"); + db.save_derived_block(DerivedRefPair { source: source1, derived: derived1 }) + .expect("save derived1"); + + db.store_block_logs(&derived2, vec![]).expect("store logs derived2"); + db.save_derived_block(DerivedRefPair { source: source1, derived: derived2 }) + .expect("save derived2"); + + db.store_block_logs(&derived3, vec![]).expect("store logs derived3"); + db.save_derived_block(DerivedRefPair { source: source1, derived: derived3 }) + .expect("save derived3"); + + assert!(db.save_source_block(source2).is_ok()); + db.store_block_logs(&derived4, vec![]).expect("store logs derived4"); + db.save_derived_block(DerivedRefPair { source: source2, derived: derived4 }) + .expect("save derived4"); + + db.store_block_logs(&derived5, vec![]).expect("store logs derived5"); + db.save_derived_block(DerivedRefPair { source: source2, derived: derived5 }) + .expect("save derived5"); + + // Advance safety heads to be ahead of anchor so that rewind will need to reset them. + db.update_current_cross_unsafe(&derived1).expect("update cross unsafe"); + db.update_current_cross_unsafe(&derived2).expect("update cross unsafe"); + db.update_current_cross_unsafe(&derived3).expect("update cross unsafe"); + db.update_current_cross_unsafe(&derived4).expect("update cross unsafe"); + + db.update_current_cross_safe(&derived1).expect("update cross safe"); + db.update_current_cross_safe(&derived2).expect("update cross safe"); + + // Now rewind to source1: expected derived rewind target is derived1 (first derived for + // source1) + let res = db.rewind_to_source(&source1.id()).expect("rewind_to_source should succeed"); + assert!(res.is_some(), "expected a derived rewind target"); + let rewind_target = res.unwrap(); + assert_eq!(rewind_target, derived1); + + // After rewind, logs should be rewound to before derived1 -> latest block == anchor.derived + let latest_log = db.get_latest_block().expect("latest block after rewind"); + assert_eq!(latest_log, anchor.derived); + + // All safety heads that were ahead should be reset to the new latest (anchor.derived) + let local_unsafe = db.get_safety_head_ref(SafetyLevel::LocalUnsafe).expect("local unsafe"); + let cross_unsafe = db.get_safety_head_ref(SafetyLevel::CrossUnsafe).expect("cross unsafe"); + let local_safe = db.get_safety_head_ref(SafetyLevel::LocalSafe).expect("local safe"); + let cross_safe = db.get_safety_head_ref(SafetyLevel::CrossSafe).expect("cross safe"); + + assert_eq!(local_unsafe, anchor.derived); + assert_eq!(cross_unsafe, anchor.derived); + assert_eq!(local_safe, anchor.derived); + assert_eq!(cross_safe, anchor.derived); + } + + #[test] + fn test_rewind_to_source_with_empty_source_returns_none() { + let tmp_dir = TempDir::new().expect("create temp dir"); + let db_path = tmp_dir.path().join("chaindb_rewind_to_source_empty"); + let db = ChainDb::new(1, &db_path).expect("create db"); + + // Anchor (activation) + let anchor = DerivedRefPair { + source: BlockInfo { + hash: B256::from([0u8; 32]), + number: 100, + parent_hash: B256::from([1u8; 32]), + timestamp: 0, + }, + derived: BlockInfo { + hash: B256::from([2u8; 32]), + number: 0, + parent_hash: B256::from([3u8; 32]), + timestamp: 0, + }, + }; + + // Initialise DB with anchor + db.initialise_log_storage(anchor.derived).expect("initialise log storage"); + db.initialise_derivation_storage(anchor).expect("initialise derivation storage"); + + // Insert a source block that has no derived entries + let source = BlockInfo { + hash: B256::from([3u8; 32]), + number: 101, + parent_hash: anchor.source.hash, + timestamp: 0, + }; + db.save_source_block(source).expect("save source block"); + + // Rewind to the source with empty derived list -> should return None + let res = db.rewind_to_source(&source.id()).expect("rewind_to_source should succeed"); + assert!(res.is_none(), "Expected None when source has no derived blocks"); + + // Ensure latest log and derivation state remain at the anchor + let latest_log = db.get_latest_block().expect("latest block after noop rewind"); + assert_eq!(latest_log, anchor.derived); + + let latest_pair = db.latest_derivation_state().expect("latest derivation state"); + assert_eq!(latest_pair, anchor); + } } diff --git a/crates/supervisor/storage/src/metrics.rs b/crates/supervisor/storage/src/metrics.rs index a29a7264df..5aa8b0b9b1 100644 --- a/crates/supervisor/storage/src/metrics.rs +++ b/crates/supervisor/storage/src/metrics.rs @@ -42,6 +42,7 @@ impl Metrics { pub(crate) const STORAGE_METHOD_GET_FINALIZED_L1: &'static str = "get_finalized_l1"; pub(crate) const STORAGE_METHOD_REWIND_LOG_STORAGE: &'static str = "rewind_log_storage"; pub(crate) const STORAGE_METHOD_REWIND: &'static str = "rewind"; + pub(crate) const STORAGE_METHOD_REWIND_TO_SOURCE: &'static str = "rewind_to_source"; pub(crate) fn init(chain_id: ChainId) { Self::describe(); @@ -112,5 +113,6 @@ impl Metrics { Self::zero_storage_methods(chain_id, Self::STORAGE_METHOD_GET_FINALIZED_L1); Self::zero_storage_methods(chain_id, Self::STORAGE_METHOD_REWIND_LOG_STORAGE); Self::zero_storage_methods(chain_id, Self::STORAGE_METHOD_REWIND); + Self::zero_storage_methods(chain_id, Self::STORAGE_METHOD_REWIND_TO_SOURCE); } } diff --git a/crates/supervisor/storage/src/providers/derivation_provider.rs b/crates/supervisor/storage/src/providers/derivation_provider.rs index c068981cf4..f5fceb62b8 100644 --- a/crates/supervisor/storage/src/providers/derivation_provider.rs +++ b/crates/supervisor/storage/src/providers/derivation_provider.rs @@ -537,6 +537,60 @@ where Ok(()) } + + /// Rewinds the derivation storage to a specific source block. + /// This will remove all derived blocks and their traversals from the given source block onward. + /// + /// # Arguments + /// * `source` - The source block number and hash to rewind to. + /// + /// # Returns + /// [`BlockInfo`] of the derived block that was rewound to, or `None` if no derived blocks + /// were found. + pub(crate) fn rewind_to_source( + &self, + source: &BlockNumHash, + ) -> Result, StorageError> { + let mut derived_rewind_target: Option = None; + { + let mut cursor = self.tx.cursor_write::()?; + let mut walker = cursor.walk(Some(source.number))?; + while let Some(Ok((block_number, block_traversal))) = walker.next() { + if block_number == source.number && block_traversal.source.hash != source.hash { + warn!( + target: "supervisor::storage", + chain_id = %self.chain_id, + source_block_number = source.number, + expected_hash = %source.hash, + actual_hash = %block_traversal.source.hash, + "Source block hash mismatch during rewind" + ); + return Err(StorageError::ConflictError); + } + + if derived_rewind_target.is_none() && + !block_traversal.derived_block_numbers.is_empty() + { + let first_num = block_traversal.derived_block_numbers[0]; + let derived_block_pair = self.get_derived_block_pair_by_number(first_num)?; + derived_rewind_target = Some(derived_block_pair.derived.into()); + } + + walker.delete_current()?; + } + } + + // Delete all derived blocks with number ≥ `block_info.number` + if let Some(rewind_target) = derived_rewind_target { + let mut cursor = self.tx.cursor_write::()?; + let mut walker = cursor.walk(Some(rewind_target.number))?; + while let Some(Ok((_, _))) = walker.next() { + walker.delete_current()?; // we’re already walking from the rewind point + } + } + + Ok(derived_rewind_target) + } } #[cfg(test)] @@ -1158,4 +1212,106 @@ mod tests { let activation = provider.get_activation_block().expect("should exist"); assert_eq!(activation, derived1); } + + #[test] + fn rewind_to_source_returns_none_when_no_source_present() { + let db = setup_db(); + let tx = db.tx_mut().expect("Failed to get mutable tx"); + let provider = DerivationProvider::new(&tx, CHAIN_ID); + + let source = BlockNumHash { number: 9999, hash: B256::from([9u8; 32]) }; + let res = provider.rewind_to_source(&source).expect("should succeed"); + assert!(res.is_none(), "Expected None when no source traversal exists"); + } + + #[test] + fn rewind_to_source_fails_on_source_hash_mismatch() { + let db = setup_db(); + + let source1 = block_info(100, B256::from([100u8; 32]), 200); + let derived1 = block_info(0, genesis_block().hash, 200); + let pair1 = derived_pair(source1, derived1); + assert!(initialize_db(&db, &pair1).is_ok()); + + // insert a source block at number 1 with a certain hash + let source_saved = block_info(101, source1.hash, 200); + assert!(insert_source_block(&db, &source_saved).is_ok()); + + // create provider and call rewind_to_source with same number but different hash + let tx = db.tx_mut().expect("Could not get tx"); + let provider = DerivationProvider::new(&tx, CHAIN_ID); + + let mismatched_source = BlockNumHash { number: 101, hash: B256::from([42u8; 32]) }; + let result = provider.rewind_to_source(&mismatched_source); + assert!(matches!(result, Err(StorageError::ConflictError))); + } + + #[test] + fn rewind_to_source_deletes_derived_blocks_and_returns_target() { + let db = setup_db(); + + let source0 = block_info(100, B256::from([100u8; 32]), 200); + let derived0 = block_info(0, genesis_block().hash, 200); + let pair0 = derived_pair(source0, derived0); + assert!(initialize_db(&db, &pair0).is_ok()); + + // Setup source1 with derived 10,11,12 and source2 with 13,14 + let source1 = block_info(101, source0.hash, 200); + let source2 = block_info(102, source1.hash, 300); + let derived1 = block_info(1, derived0.hash, 195); + let derived2 = block_info(2, derived1.hash, 197); + let derived3 = block_info(3, derived2.hash, 290); + let derived4 = block_info(4, derived3.hash, 292); + let derived5 = block_info(5, derived4.hash, 295); + + assert!(insert_source_block(&db, &source1).is_ok()); + assert!(insert_pair(&db, &derived_pair(source1, derived1)).is_ok()); + assert!(insert_pair(&db, &derived_pair(source1, derived2)).is_ok()); + assert!(insert_pair(&db, &derived_pair(source1, derived3)).is_ok()); + + assert!(insert_source_block(&db, &source2).is_ok()); + assert!(insert_pair(&db, &derived_pair(source2, derived4)).is_ok()); + assert!(insert_pair(&db, &derived_pair(source2, derived5)).is_ok()); + + // Perform rewind_to_source starting at source1 + let tx = db.tx_mut().expect("Could not get mutable tx"); + let provider = DerivationProvider::new(&tx, CHAIN_ID); + let source_id = BlockNumHash { number: source1.number, hash: source1.hash }; + let res = provider.rewind_to_source(&source_id).expect("rewind should succeed"); + + // derived_rewind_target should be the first derived block encountered (10) + assert!(res.is_some(), "expected a derived rewind target"); + let target = res.unwrap(); + assert_eq!(target, derived1); + + let res = provider.get_derived_block_pair_by_number(10); + assert!(matches!(res, Err(StorageError::EntryNotFound(_)))); + } + + #[test] + fn rewind_to_source_with_empty_derived_list_returns_none() { + let db = setup_db(); + + let source0 = block_info(100, B256::from([100u8; 32]), 200); + let derived0 = block_info(0, genesis_block().hash, 200); + let pair0 = derived_pair(source0, derived0); + assert!(initialize_db(&db, &pair0).is_ok()); + + // Insert a source block that has no derived_block_numbers + let source1 = block_info(101, source0.hash, 200); + assert!(insert_source_block(&db, &source1).is_ok()); + + // Call rewind_to_source on that source + let tx = db.tx_mut().expect("Could not get mutable tx"); + let provider = DerivationProvider::new(&tx, CHAIN_ID); + let res = provider.rewind_to_source(&source1.id()).expect("rewind should succeed"); + + assert!(res.is_none(), "Expected None when source has empty derived list"); + + let tx = db.tx().expect("Could not get tx"); + let provider = DerivationProvider::new(&tx, CHAIN_ID); + + let activation = provider.get_activation_block().expect("activation should exist"); + assert_eq!(activation, derived0); + } } diff --git a/crates/supervisor/storage/src/traits.rs b/crates/supervisor/storage/src/traits.rs index c012e1bf2a..b7a527ccfa 100644 --- a/crates/supervisor/storage/src/traits.rs +++ b/crates/supervisor/storage/src/traits.rs @@ -452,6 +452,18 @@ pub trait StorageRewinder { /// # Errors /// Returns a [`StorageError`] if any part of the rewind process fails. fn rewind(&self, to: &BlockNumHash) -> Result<(), StorageError>; + + /// Rewinds the storage to a specific source block (inclusive), ensuring that all derived blocks + /// and logs associated with that source blocks are also reverted. + /// + /// # Arguments + /// * `to` - The source block [`BlockNumHash`] to rewind to. + /// + /// # Returns + /// * [`BlockInfo`] of the derived block that was rewound to, or `None` if no derived blocks + /// were found. + /// * `Err(StorageError)` if there is an issue during the rewind operation. + fn rewind_to_source(&self, to: &BlockNumHash) -> Result, StorageError>; } /// Combines the reader traits for the database.