diff --git a/unified-scheduler-pool/src/lib.rs b/unified-scheduler-pool/src/lib.rs index f9c92b9595eab0..f88a7bc04106a7 100644 --- a/unified-scheduler-pool/src/lib.rs +++ b/unified-scheduler-pool/src/lib.rs @@ -262,11 +262,11 @@ clone_trait_object!(BankingPacketHandler); pub struct BankingStageHelper { usage_queue_loader: UsageQueueLoader, next_task_id: AtomicUsize, - new_task_sender: Sender, + new_task_sender: Weak>, } impl BankingStageHelper { - fn new(new_task_sender: Sender) -> Self { + fn new(new_task_sender: Weak>) -> Self { Self { usage_queue_loader: UsageQueueLoader::default(), next_task_id: AtomicUsize::default(), @@ -296,6 +296,8 @@ impl BankingStageHelper { pub fn send_new_task(&self, task: Task) { self.new_task_sender + .upgrade() + .unwrap() .send(NewTaskPayload::Payload(task)) .unwrap(); } @@ -619,7 +621,7 @@ where fn create_handler_context( &self, mode: SchedulingMode, - new_task_sender: &Sender, + new_task_sender: &Arc>, ) -> HandlerContext { let ( thread_count, @@ -652,7 +654,9 @@ where handler_context.banking_thread_count, handler_context.banking_packet_receiver.clone(), handler_context.banking_packet_handler.clone(), - Some(Arc::new(BankingStageHelper::new(new_task_sender.clone()))), + Some(Arc::new(BankingStageHelper::new(Arc::downgrade( + new_task_sender, + )))), Some(handler_context.transaction_recorder.clone()), ) } @@ -1207,7 +1211,7 @@ where // Ensure to initiate thread shutdown via disconnected new_task_receiver by replacing the // current new_task_sender with a random one... - self.new_task_sender = crossbeam_channel::unbounded().0; + self.new_task_sender = Arc::new(crossbeam_channel::unbounded().0); self.ensure_join_threads(true); assert_matches!(self.session_result_with_timings, Some((Ok(_), _))); @@ -1241,7 +1245,11 @@ where } fn is_overgrown(&self) -> bool { - self.usage_queue_loader.count() > self.thread_manager.pool.max_usage_queue_count + self.thread_manager + .pool + .upgrade() + .map(|pool| self.usage_queue_loader.count() > pool.max_usage_queue_count) + .unwrap_or_default() } } @@ -1253,8 +1261,8 @@ where #[derive(Debug)] struct ThreadManager, TH: TaskHandler> { scheduler_id: SchedulerId, - pool: Arc>, - new_task_sender: Sender, + pool: Weak>, + new_task_sender: Arc>, new_task_receiver: Option>, session_result_sender: Sender, session_result_receiver: Receiver, @@ -1267,14 +1275,14 @@ struct HandlerPanicked; type HandlerResult = std::result::Result, HandlerPanicked>; impl, TH: TaskHandler> ThreadManager { - fn new(pool: Arc>) -> Self { + fn new(pool: &Arc>) -> Self { let (new_task_sender, new_task_receiver) = crossbeam_channel::unbounded(); let (session_result_sender, session_result_receiver) = crossbeam_channel::unbounded(); Self { scheduler_id: pool.new_scheduler_id(), - pool, - new_task_sender, + pool: Arc::downgrade(pool), + new_task_sender: Arc::new(new_task_sender), new_task_receiver: Some(new_task_receiver), session_result_sender, session_result_receiver, @@ -2170,7 +2178,7 @@ impl SpawnableScheduler for PooledScheduler { result_with_timings: ResultWithTimings, ) -> Self { let mut inner = Self::Inner { - thread_manager: ThreadManager::new(pool.clone()), + thread_manager: ThreadManager::new(&pool), usage_queue_loader: UsageQueueLoader::default(), }; inner.thread_manager.start_threads( @@ -2257,7 +2265,9 @@ where TH: TaskHandler, { fn return_to_pool(self: Box) { - self.thread_manager.pool.clone().return_scheduler(*self); + if let Some(pool) = self.thread_manager.pool.upgrade() { + pool.clone().return_scheduler(*self); + } } } @@ -3771,7 +3781,7 @@ mod tests { &task, &pool.create_handler_context( BlockVerification, - &crossbeam_channel::unbounded().0, + &Arc::new(crossbeam_channel::unbounded().0), ), ); (result, timings)