diff --git a/core/src/banking_stage/consume_worker.rs b/core/src/banking_stage/consume_worker.rs index 5446db7ecf59bd..720a4091df27d6 100644 --- a/core/src/banking_stage/consume_worker.rs +++ b/core/src/banking_stage/consume_worker.rs @@ -72,7 +72,8 @@ impl ConsumeWorker { } fn consume_loop(&self, work: ConsumeWork) -> Result<(), ConsumeWorkerError> { - let (maybe_consume_bank, get_bank_us) = measure_us!(self.working_bank_with_timeout()); + let (maybe_consume_bank, get_bank_us) = + measure_us!(self.new_working_bank_with_timeout(None)); let Some(mut bank) = maybe_consume_bank else { self.metrics .timing_metrics @@ -93,27 +94,22 @@ impl ConsumeWorker { if self.exit.load(Ordering::Relaxed) { return Ok(()); } - if bank.is_complete() || { - // if working bank has changed, then try to get a new bank. - self.working_bank() - .map(|working_bank| Arc::ptr_eq(&working_bank, &bank)) - .unwrap_or(true) - } { - let (maybe_new_bank, get_bank_us) = measure_us!(self.working_bank_with_timeout()); - if let Some(new_bank) = maybe_new_bank { - self.metrics - .timing_metrics - .wait_for_bank_success_us - .fetch_add(get_bank_us, Ordering::Relaxed); - bank = new_bank; - } else { - self.metrics - .timing_metrics - .wait_for_bank_failure_us - .fetch_add(get_bank_us, Ordering::Relaxed); - return self.retry_drain(work); - } + + // If necessary, get a new bank to consume against. + let (bank_usable, update_bank_us) = + measure_us!(self.update_working_bank_if_necessary(&mut bank)); + if !bank_usable { + self.metrics + .timing_metrics + .wait_for_bank_failure_us + .fetch_add(update_bank_us, Ordering::Relaxed); + return self.retry_drain(work); } + + self.metrics + .timing_metrics + .wait_for_bank_success_us + .fetch_add(update_bank_us, Ordering::Relaxed); self.metrics .count_metrics .num_messages_processed @@ -150,21 +146,47 @@ impl ConsumeWorker { /// Get the current poh working bank with a timeout - if the Bank is /// not available within the timeout, return None. - fn working_bank_with_timeout(&self) -> Option> { + fn new_working_bank_with_timeout(&self, current_bank: Option<&Arc>) -> Option> { const TIMEOUT: Duration = Duration::from_millis(50); let now = Instant::now(); while now.elapsed() < TIMEOUT { - if let Some(bank) = self.working_bank() { - return Some(bank); + if let Some(new_bank) = self.shared_working_bank.load_ref().as_ref() { + match current_bank { + Some(current_bank) if Arc::ptr_eq(new_bank, current_bank) => {} + // If we don't currently have a bank OR we have a new bank, return it. + _ => { + return Some(Arc::clone(new_bank)); + } + } } } None } - /// Get the current poh working bank without a timeout. - fn working_bank(&self) -> Option> { - self.shared_working_bank.load() + /// Update the bank if it has changed. + /// Returns true if the bank is updated or still usable. + fn update_working_bank_if_necessary(&self, bank: &mut Arc) -> bool { + if let Some(working_bank) = self.shared_working_bank.load_ref().as_ref() { + if !Arc::ptr_eq(working_bank, bank) { + // If we've loaded a new bank, update to it. + *bank = Arc::clone(working_bank); + } + + // If the bank, whether new or old, is still not complete, return true. + if !bank.is_complete() { + return true; + } + } + + // If `working_bank` is None or the bank is complete, we try to get the next bank. + // If this is the last leader slot in our rotation, we will timeout. + if let Some(new_bank) = self.new_working_bank_with_timeout(Some(bank)) { + *bank = new_bank; + return true; + } + + false } /// Retry current batch and all outstanding batches. diff --git a/core/src/banking_stage/decision_maker.rs b/core/src/banking_stage/decision_maker.rs index c534ccd5c698ef..722876f7009965 100644 --- a/core/src/banking_stage/decision_maker.rs +++ b/core/src/banking_stage/decision_maker.rs @@ -42,7 +42,7 @@ pub struct DecisionMaker { impl std::fmt::Debug for DecisionMaker { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("DecisionMaker") - .field("shared_working_bank", &self.shared_working_bank.load()) + .field("shared_working_bank", &self.shared_working_bank.load_full()) .field("shared_tick_height", &self.shared_tick_height.load()) .field( "shared_leader_first_tick_height", @@ -67,7 +67,7 @@ impl DecisionMaker { pub(crate) fn make_consume_or_forward_decision(&self) -> BufferedPacketsDecision { // Check if there is an active working bank. - if let Some(bank) = self.shared_working_bank.load() { + if let Some(bank) = self.shared_working_bank.load_full() { BufferedPacketsDecision::Consume(bank) } else if let Some(first_leader_tick_height) = self.shared_leader_first_tick_height.load() { let current_tick_height = self.shared_tick_height.load(); diff --git a/core/src/replay_stage.rs b/core/src/replay_stage.rs index 88bfcc5caa2ade..b67954ef672499 100644 --- a/core/src/replay_stage.rs +++ b/core/src/replay_stage.rs @@ -772,7 +772,7 @@ impl ReplayStage { // We either have a bank currently, OR there is a pending message to either reset or set // the bank. let tpu_has_bank = - shared_poh_bank.load().is_some() || poh_controller.has_pending_message(); + shared_poh_bank.load_full().is_some() || poh_controller.has_pending_message(); let mut replay_active_banks_time = Measure::start("replay_active_banks_time"); let (mut ancestors, mut descendants) = { @@ -1171,7 +1171,7 @@ impl ReplayStage { let mut dump_then_repair_correct_slots_time = Measure::start("dump_then_repair_correct_slots_time"); // Used for correctness check - let poh_bank = shared_poh_bank.load(); + let poh_bank = shared_poh_bank.load_full(); // Dump any duplicate slots that have been confirmed by the network in // anticipation of repairing the confirmed version of the slot. // diff --git a/poh/src/poh_recorder.rs b/poh/src/poh_recorder.rs index 9b990b7ae857f9..d2a0d3aa0d8a4b 100644 --- a/poh/src/poh_recorder.rs +++ b/poh/src/poh_recorder.rs @@ -1013,10 +1013,14 @@ pub fn create_test_recorder_with_index_tracking( pub struct SharedWorkingBank(Arc>); impl SharedWorkingBank { - pub fn load(&self) -> Option> { + pub fn load_full(&self) -> Option> { self.0.load_full() } + pub fn load_ref(&self) -> arc_swap::Guard>> { + self.0.load() + } + // Mutable access not needed for this function. // However we use it to guarantee only used when PohRecorder is // write locked.