diff --git a/crates/supervisor/core/src/l1_watcher/watcher.rs b/crates/supervisor/core/src/l1_watcher/watcher.rs index 2063eb0441..cfe9622ffb 100644 --- a/crates/supervisor/core/src/l1_watcher/watcher.rs +++ b/crates/supervisor/core/src/l1_watcher/watcher.rs @@ -212,7 +212,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::SupervisorError; + use crate::{SupervisorError, syncnode::ManagedNodeController}; use alloy_primitives::B256; use alloy_transport::mock::*; use kona_supervisor_storage::{ChainDb, FinalizedL1Storage, StorageError}; @@ -242,7 +242,8 @@ mod tests { fn mock_reorg_handler() -> ReorgHandler { let chain_dbs_map: HashMap> = HashMap::new(); - ReorgHandler::new(mock_rpc_client(), chain_dbs_map) + let managed_nodes: HashMap> = HashMap::new(); + ReorgHandler::new(mock_rpc_client(), chain_dbs_map, managed_nodes) } #[tokio::test] diff --git a/crates/supervisor/core/src/reorg/handler.rs b/crates/supervisor/core/src/reorg/handler.rs index 37d8ad412f..574b6ab977 100644 --- a/crates/supervisor/core/src/reorg/handler.rs +++ b/crates/supervisor/core/src/reorg/handler.rs @@ -1,4 +1,4 @@ -use crate::{SupervisorError, reorg::task::ReorgTask}; +use crate::{SupervisorError, reorg::task::ReorgTask, syncnode::ManagedNodeController}; use alloy_primitives::ChainId; use alloy_rpc_client::RpcClient; use derive_more::Constructor; @@ -15,6 +15,8 @@ pub struct ReorgHandler { rpc_client: RpcClient, /// Per chain dbs. chain_dbs: HashMap>, + /// Per chain managed nodes + managed_nodes: HashMap>, } impl ReorgHandler @@ -32,8 +34,17 @@ where let mut handles = Vec::with_capacity(self.chain_dbs.len()); for (chain_id, chain_db) in &self.chain_dbs { - let reorg_task = - ReorgTask::new(*chain_id, Arc::clone(chain_db), self.rpc_client.clone()); + let managed_node = self.managed_nodes.get(chain_id).ok_or( + SupervisorError::Initialise("no managed node found for chain".to_string()), + )?; + + let reorg_task = ReorgTask::new( + *chain_id, + Arc::clone(chain_db), + self.rpc_client.clone(), + Arc::clone(managed_node), + ); + let handle = tokio::spawn(async move { reorg_task.process_chain_reorg().await }); handles.push(handle); } diff --git a/crates/supervisor/core/src/reorg/task.rs b/crates/supervisor/core/src/reorg/task.rs index cece69d00f..5209cd6e9b 100644 --- a/crates/supervisor/core/src/reorg/task.rs +++ b/crates/supervisor/core/src/reorg/task.rs @@ -1,4 +1,4 @@ -use crate::SupervisorError; +use crate::{SupervisorError, syncnode::ManagedNodeController}; use alloy_eips::BlockNumHash; use alloy_primitives::{B256, ChainId}; use alloy_rpc_client::RpcClient; @@ -14,6 +14,7 @@ pub(crate) struct ReorgTask { chain_id: ChainId, db: Arc, rpc_client: RpcClient, + managed_node: Arc, } impl ReorgTask @@ -44,6 +45,23 @@ where ); })?; + trace!( + target: "supervisor::reorg_handler", + chain_id = %self.chain_id, + "Calling resetter to reset the node after reorg" + ); + + // Reset the node after rewinding the DB. + self.managed_node.reset().await.map_err(|err| { + warn!( + target: "supervisor::reorg_handler", + chain_id = %self.chain_id, + %err, + "Failed to reset node after reorg" + ); + SupervisorError::from(err) + })?; + Ok(()) } @@ -129,14 +147,16 @@ where #[cfg(test)] mod tests { use super::*; + use crate::syncnode::{ManagedNodeController, ManagedNodeError}; use alloy_rpc_types_eth::Header; use alloy_transport::mock::*; + use async_trait::async_trait; use kona_interop::{DerivedRefPair, SafetyLevel}; use kona_protocol::BlockInfo; use kona_supervisor_storage::{ DerivationStorageReader, HeadRefStorageReader, LogStorageReader, StorageError, }; - use kona_supervisor_types::{Log, SuperHead}; + use kona_supervisor_types::{BlockSeal, Log, SuperHead}; use mockall::mock; mock!( @@ -172,6 +192,20 @@ mod tests { pub chain_db {} ); + mock! ( + #[derive(Debug)] + pub ManagedNode {} + + #[async_trait] + impl ManagedNodeController for ManagedNode { + async fn reset(&self) -> Result<(), ManagedNodeError>; + async fn update_finalized(&self, finalized_block_id: BlockNumHash) -> Result<(), ManagedNodeError>; + async fn update_cross_unsafe(&self, cross_unsafe_block_id: BlockNumHash) -> Result<(), ManagedNodeError>; + async fn update_cross_safe(&self, source_block_id: BlockNumHash, derived_block_id: BlockNumHash) -> Result<(), ManagedNodeError>; + async fn invalidate_block(&self, seal: BlockSeal) -> Result<(), ManagedNodeError>; + } + ); + #[tokio::test] async fn test_find_rewind_target_without_reorg() { let mut mock_db = MockDb::new(); @@ -208,7 +242,8 @@ mod tests { // Mock RPC response asserter.push_success(&latest_source); - let reorg_task = ReorgTask::new(1, Arc::new(mock_db), rpc_client); + let managed_node = Arc::new(MockManagedNode::new()); + let reorg_task = ReorgTask::new(1, Arc::new(mock_db), rpc_client, managed_node); let rewind_target = reorg_task.find_rewind_target().await; // Should succeed since the latest source block is still canonical @@ -342,7 +377,8 @@ mod tests { // Finally returning the correct block asserter.push_success(&finalized_source); - let reorg_task = ReorgTask::new(1, Arc::new(mock_db), rpc_client); + let managed_node = Arc::new(MockManagedNode::new()); + let reorg_task = ReorgTask::new(1, Arc::new(mock_db), rpc_client, managed_node); let rewind_target = reorg_task.find_rewind_target().await; // Should succeed since the latest source block is still canonical @@ -389,7 +425,8 @@ mod tests { asserter.push_success(&canonical_block); asserter.push_success(&non_canonical_block); - let reorg_task = ReorgTask::new(1, Arc::new(MockDb::new()), rpc_client); + let managed_node = Arc::new(MockManagedNode::new()); + let reorg_task = ReorgTask::new(1, Arc::new(MockDb::new()), rpc_client, managed_node); let result = reorg_task.is_block_canonical(100, canonical_hash).await; assert!(result.is_ok()); diff --git a/crates/supervisor/core/src/supervisor.rs b/crates/supervisor/core/src/supervisor.rs index a4d045bead..c72b6590e2 100644 --- a/crates/supervisor/core/src/supervisor.rs +++ b/crates/supervisor/core/src/supervisor.rs @@ -30,7 +30,9 @@ use crate::{ l1_watcher::L1Watcher, reorg::ReorgHandler, safety_checker::{CrossSafePromoter, CrossUnsafePromoter}, - syncnode::{Client, ManagedNode, ManagedNodeClient, ManagedNodeDataProvider}, + syncnode::{ + Client, ManagedNode, ManagedNodeClient, ManagedNodeController, ManagedNodeDataProvider, + }, }; /// Defines the service for the Supervisor core logic. @@ -293,12 +295,20 @@ impl Supervisor { .map(|chain_id| (*chain_id, self.database_factory.get_db(*chain_id).unwrap())) .collect(); + let managed_nodes = self + .managed_nodes + .iter() + .map(|(chain_id, managed_node)| { + (*chain_id, managed_node.clone() as Arc) + }) + .collect(); + let l1_watcher = L1Watcher::new( l1_rpc.clone(), self.database_factory.clone(), senders, self.cancel_token.clone(), - ReorgHandler::new(l1_rpc, chain_dbs_map), + ReorgHandler::new(l1_rpc, chain_dbs_map, managed_nodes), ); tokio::spawn(async move {