diff --git a/crates/engine/tree/src/tree/payload_processor/multiproof.rs b/crates/engine/tree/src/tree/payload_processor/multiproof.rs index 755f7a7d0d7..321de725bec 100644 --- a/crates/engine/tree/src/tree/payload_processor/multiproof.rs +++ b/crates/engine/tree/src/tree/payload_processor/multiproof.rs @@ -19,7 +19,10 @@ use reth_trie::{ }; use reth_trie_parallel::{ proof::ParallelProof, - proof_task::{AccountMultiproofInput, ProofResultMessage, ProofWorkerHandle}, + proof_task::{ + AccountMultiproofInput, ProofResultContext, ProofResultMessage, ProofWorkerHandle, + StorageProofInput, + }, }; use std::{collections::BTreeMap, ops::DerefMut, sync::Arc, time::Instant}; use tracing::{debug, error, instrument, trace}; @@ -408,7 +411,7 @@ impl MultiproofManager { let prefix_set = prefix_set.freeze(); // Build computation input (data only) - let input = reth_trie_parallel::proof_task::StorageProofInput::new( + let input = StorageProofInput::new( hashed_address, prefix_set, proof_targets, @@ -419,7 +422,7 @@ impl MultiproofManager { // Dispatch to storage worker if let Err(e) = self.proof_worker_handle.dispatch_storage_proof( input, - reth_trie_parallel::proof_task::ProofResultContext::new( + ProofResultContext::new( self.proof_result_tx.clone(), proof_sequence_number, hashed_state_update, @@ -492,7 +495,7 @@ impl MultiproofManager { multi_added_removed_keys, missed_leaves_storage_roots, // Workers will send ProofResultMessage directly to proof_result_rx - proof_result_sender: reth_trie_parallel::proof_task::ProofResultContext::new( + proof_result_sender: ProofResultContext::new( self.proof_result_tx.clone(), proof_sequence_number, hashed_state_update, @@ -1131,7 +1134,7 @@ impl MultiProofTask { // Convert ProofResultMessage to SparseTrieUpdate match proof_result.result { - Ok((multiproof, _stats)) => { + Ok(proof_result_data) => { debug!( target: "engine::tree::payload_processor::multiproof", sequence = proof_result.sequence_number, @@ -1141,7 +1144,7 @@ impl MultiProofTask { let update = SparseTrieUpdate { state: proof_result.state, - multiproof, + multiproof: proof_result_data.into_multiproof(), }; if let Some(combined_update) = diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index 63d26993d50..4d54359d1bf 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -139,15 +139,20 @@ impl ParallelProof { ))) })?; - // Extract the multiproof from the result - let (mut multiproof, _stats) = proof_msg.result?; - - // Extract storage proof from the multiproof - let storage_proof = multiproof.storages.remove(&hashed_address).ok_or_else(|| { - ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other( - format!("storage proof not found in multiproof for {hashed_address}"), - ))) - })?; + // Extract storage proof directly from the result + let storage_proof = match proof_msg.result? { + crate::proof_task::ProofResult::StorageProof { hashed_address: addr, proof } => { + debug_assert_eq!( + addr, + hashed_address, + "storage worker must return same address: expected {hashed_address}, got {addr}" + ); + proof + } + crate::proof_task::ProofResult::AccountMultiproof { .. } => { + unreachable!("storage worker only sends StorageProof variant") + } + }; trace!( target: "trie::parallel_proof", @@ -231,7 +236,12 @@ impl ParallelProof { ) })?; - let (multiproof, stats) = proof_result_msg.result?; + let (multiproof, stats) = match proof_result_msg.result? { + crate::proof_task::ProofResult::AccountMultiproof { proof, stats } => (proof, stats), + crate::proof_task::ProofResult::StorageProof { .. } => { + unreachable!("account worker only sends AccountMultiproof variant") + } + }; #[cfg(feature = "metrics")] self.metrics.record(stats); diff --git a/crates/trie/parallel/src/proof_task.rs b/crates/trie/parallel/src/proof_task.rs index c05f2ad7286..1b50dbe73ef 100644 --- a/crates/trie/parallel/src/proof_task.rs +++ b/crates/trie/parallel/src/proof_task.rs @@ -84,8 +84,40 @@ use crate::proof_task_metrics::ProofTaskTrieMetrics; type StorageProofResult = Result; type TrieNodeProviderResult = Result, SparseTrieError>; -type AccountMultiproofResult = - Result<(DecodedMultiProof, ParallelTrieStats), ParallelStateRootError>; + +/// Result of a proof calculation, which can be either an account multiproof or a storage proof. +#[derive(Debug)] +pub enum ProofResult { + /// Account multiproof with statistics + AccountMultiproof { + /// The account multiproof + proof: DecodedMultiProof, + /// Statistics collected during proof computation + stats: ParallelTrieStats, + }, + /// Storage proof for a specific account + StorageProof { + /// The hashed address this storage proof belongs to + hashed_address: B256, + /// The storage multiproof + proof: DecodedStorageMultiProof, + }, +} + +impl ProofResult { + /// Convert this proof result into a `DecodedMultiProof`. + /// + /// For account multiproofs, returns the multiproof directly (discarding stats). + /// For storage proofs, wraps the storage proof into a minimal multiproof. + pub fn into_multiproof(self) -> DecodedMultiProof { + match self { + Self::AccountMultiproof { proof, stats: _ } => proof, + Self::StorageProof { hashed_address, proof } => { + DecodedMultiProof::from_storage_proof(hashed_address, proof) + } + } + } +} /// Channel used by worker threads to deliver `ProofResultMessage` items back to /// `MultiProofTask`. @@ -101,8 +133,8 @@ pub type ProofResultSender = CrossbeamSender; pub struct ProofResultMessage { /// Sequence number for ordering proofs pub sequence_number: u64, - /// The proof calculation result - pub result: AccountMultiproofResult, + /// The proof calculation result (either account multiproof or storage proof) + pub result: Result, /// Time taken for the entire proof calculation (from dispatch to completion) pub elapsed: Duration, /// Original state update that triggered this proof @@ -248,18 +280,10 @@ fn storage_worker_loop( let proof_elapsed = proof_start.elapsed(); storage_proofs_processed += 1; - // Convert storage proof to account multiproof format - let result_msg = match result { - Ok(storage_proof) => { - let multiproof = reth_trie::DecodedMultiProof::from_storage_proof( - hashed_address, - storage_proof, - ); - let stats = crate::stats::ParallelTrieTracker::default().finish(); - Ok((multiproof, stats)) - } - Err(e) => Err(e), - }; + let result_msg = result.map(|storage_proof| ProofResult::StorageProof { + hashed_address, + proof: storage_proof, + }); if sender .send(ProofResultMessage { @@ -496,7 +520,7 @@ fn account_worker_loop( let proof_elapsed = proof_start.elapsed(); let total_elapsed = start.elapsed(); let stats = tracker.finish(); - let result = result.map(|proof| (proof, stats)); + let result = result.map(|proof| ProofResult::AccountMultiproof { proof, stats }); account_proofs_processed += 1; // Send result to MultiProofTask @@ -657,14 +681,20 @@ where ) })?; - // Extract storage proof from the multiproof wrapper - let (mut multiproof, _stats) = proof_msg.result?; - let proof = - multiproof.storages.remove(&hashed_address).ok_or_else(|| { - ParallelStateRootError::Other(format!( - "storage proof not found in multiproof for {hashed_address}" - )) - })?; + // Extract storage proof from the result + let proof = match proof_msg.result? { + ProofResult::StorageProof { hashed_address: addr, proof } => { + debug_assert_eq!( + addr, + hashed_address, + "storage worker must return same address: expected {hashed_address}, got {addr}" + ); + proof + } + ProofResult::AccountMultiproof { .. } => { + unreachable!("storage worker only sends StorageProof variant") + } + }; let root = proof.root; collected_decoded_storages.insert(hashed_address, proof); @@ -716,10 +746,8 @@ where // Consume remaining storage proof receivers for accounts not encountered during trie walk. for (hashed_address, receiver) in storage_proof_receivers { if let Ok(proof_msg) = receiver.recv() { - // Extract storage proof from the multiproof wrapper - if let Ok((mut multiproof, _stats)) = proof_msg.result && - let Some(proof) = multiproof.storages.remove(&hashed_address) - { + // Extract storage proof from the result + if let Ok(ProofResult::StorageProof { proof, .. }) = proof_msg.result { collected_decoded_storages.insert(hashed_address, proof); } }