From 5e42e1f90e4da3e04316c0e85500b2d278eecb51 Mon Sep 17 00:00:00 2001
From: Pepijn Van Eeckhoudt
Date: Sat, 7 Jun 2025 15:55:46 +0200
Subject: [PATCH 1/3] bug: remove busy-wait while sort is ongoing (#16321)

---
 datafusion/physical-plan/src/sorts/merge.rs   | 56 +++++++++++--------
 .../src/sorts/sort_preserving_merge.rs        |  2 +-
 2 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs
index 2b42457635f7b..0c18a3b6c7032 100644
--- a/datafusion/physical-plan/src/sorts/merge.rs
+++ b/datafusion/physical-plan/src/sorts/merge.rs
@@ -18,7 +18,6 @@
 //! Merge that deals with an arbitrary size of streaming inputs.
 //! This is an order-preserving merge.
 
-use std::collections::VecDeque;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{ready, Context, Poll};
@@ -143,11 +142,8 @@ pub(crate) struct SortPreservingMergeStream {
     /// number of rows produced
     produced: usize,
 
-    /// This queue contains partition indices in order. When a partition is polled and returns `Poll::Ready`,
-    /// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the
-    /// vector to ensure the next iteration starts with a different partition, preventing the same partition
-    /// from being continuously polled.
-    uninitiated_partitions: VecDeque<usize>,
+    /// This vector contains the indices of the partitions that have not started emitting yet.
+    uninitiated_partitions: Vec<usize>,
 }
 
 impl SortPreservingMergeStream {
@@ -216,36 +212,50 @@ impl SortPreservingMergeStream {
         // Once all partitions have set their corresponding cursors for the loser tree,
         // we skip the following block. Until then, this function may be called multiple
         // times and can return Poll::Pending if any partition returns Poll::Pending.
+
         if self.loser_tree.is_empty() {
-            while let Some(&partition_idx) = self.uninitiated_partitions.front() {
+            // Manual indexing since we're iterating over the vector and shrinking it in the loop
+            let mut idx = 0;
+            while idx < self.uninitiated_partitions.len() {
+                let partition_idx = self.uninitiated_partitions[idx];
                 match self.maybe_poll_stream(cx, partition_idx) {
                     Poll::Ready(Err(e)) => {
                         self.aborted = true;
                         return Poll::Ready(Some(Err(e)));
                     }
                     Poll::Pending => {
-                        // If a partition returns Poll::Pending, to avoid continuously polling it
-                        // and potentially increasing upstream buffer sizes, we move it to the
-                        // back of the polling queue.
-                        self.uninitiated_partitions.rotate_left(1);
-
-                        // This function could remain in a pending state, so we manually wake it here.
-                        // However, this approach can be investigated further to find a more natural way
-                        // to avoid disrupting the runtime scheduler.
-                        cx.waker().wake_by_ref();
-                        return Poll::Pending;
+                        // The polled stream is pending which means we're already set up to
+                        // be woken when necessary
+                        // Try the next stream
+                        idx += 1;
                     }
                     _ => {
-                        // If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None),
-                        // we remove this partition from the queue so it is not polled again.
-                        self.uninitiated_partitions.pop_front();
+                        // The polled stream is ready
+                        // Remove it from uninitiated_partitions
+                        // Don't bump idx here, since a new element will have taken its
+                        // place which we'll try in the next loop iteration
+                        // swap_remove will change the partition poll order, but that shouldn't
+                        // make a difference since we're waiting for all streams to be ready.
+                        self.uninitiated_partitions.swap_remove(idx);
                     }
                 }
             }
 
-            // Claim the memory for the uninitiated partitions
-            self.uninitiated_partitions.shrink_to_fit();
-            self.init_loser_tree();
+            if self.uninitiated_partitions.is_empty() {
+                // If there are no more uninitiated partitions, set up the loser tree and continue
+                // to the next phase.
+
+                // Claim the memory for the uninitiated partitions
+                self.uninitiated_partitions.shrink_to_fit();
+                self.init_loser_tree();
+            } else {
+                // There are still uninitiated partitions so return pending.
+                // We only get here if we've polled all uninitiated streams and at least one of them
+                // returned pending itself. That means we will be woken as soon as one of the
+                // streams would like to be polled again.
+                // There is no need to reschedule ourselves eagerly.
+                return Poll::Pending;
+            }
         }
 
         // NB timer records time taken on drop, so there are no
diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
index 272b8f6d75e00..a9406c5c0d21d 100644
--- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
+++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -1386,7 +1386,7 @@ mod tests {
             match self.partition {
                 0 => {
                     if self.none_polled_once {
-                        panic!("Exhausted stream is polled more than one")
+                        panic!("Exhausted stream is polled more than once")
                     } else {
                         self.none_polled_once = true;
                         Poll::Ready(None)

From 12eaae991595af406ef1a8706a9ad12d2842d5be Mon Sep 17 00:00:00 2001
From: Pepijn Van Eeckhoudt
Date: Sun, 8 Jun 2025 09:43:39 +0200
Subject: [PATCH 2/3] bug: make CongestedStream a correct Stream implementation

---
 .../src/sorts/sort_preserving_merge.rs        | 55 ++++++++++++++-----
 1 file changed, 41 insertions(+), 14 deletions(-)

diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
index a9406c5c0d21d..b5f192c437ec8 100644
--- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
+++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -381,7 +381,7 @@ mod tests {
     use std::fmt::Formatter;
     use std::pin::Pin;
     use std::sync::Mutex;
-    use std::task::{Context, Poll};
+    use std::task::{ready, Context, Poll, Waker};
     use std::time::Duration;
 
     use super::*;
@@ -1285,13 +1285,45 @@ mod tests {
     "#);
     }
 
+    #[derive(Debug)]
+    struct Congestion {
+        congestion_cleared: Mutex<Option<Vec<Waker>>>,
+    }
+
+    impl Congestion {
+        fn new() -> Self {
+            Congestion {
+                congestion_cleared: Mutex::new(Some(vec![])),
+            }
+        }
+
+        fn clear_congestion(&self) {
+            let mut cleared = self.congestion_cleared.lock().unwrap();
+            if let Some(wakers) = &mut *cleared {
+                wakers.iter().for_each(|w| w.wake_by_ref());
+                *cleared = None;
+            }
+        }
+
+        fn check_congested(&self, cx: &mut Context<'_>) -> Poll<()> {
+            let mut cleared = self.congestion_cleared.lock().unwrap();
+            match &mut *cleared {
+                None => Poll::Ready(()),
+                Some(wakers) => {
+                    wakers.push(cx.waker().clone());
+                    Poll::Pending
+                }
+            }
+        }
+    }
+
     /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
     /// partition is exhausted from the start, and if it is polled more than one, it panics.
     #[derive(Debug, Clone)]
     struct CongestedExec {
         schema: Schema,
         cache: PlanProperties,
-        congestion_cleared: Arc<Mutex<bool>>,
+        congestion: Arc<Congestion>,
     }
 
     impl CongestedExec {
@@ -1346,7 +1378,7 @@
             Ok(Box::pin(CongestedStream {
                 schema: Arc::new(self.schema.clone()),
                 none_polled_once: false,
-                congestion_cleared: Arc::clone(&self.congestion_cleared),
+                congestion: Arc::clone(&self.congestion),
                 partition,
             }))
         }
@@ -1373,7 +1405,7 @@
     pub struct CongestedStream {
         schema: SchemaRef,
         none_polled_once: bool,
-        congestion_cleared: Arc<Mutex<bool>>,
+        congestion: Arc<Congestion>,
        partition: usize,
     }
 
@@ -1381,7 +1413,7 @@ impl Stream for CongestedStream {
         type Item = Result<RecordBatch>;
         fn poll_next(
             mut self: Pin<&mut Self>,
-            _cx: &mut Context<'_>,
+            cx: &mut Context<'_>,
         ) -> Poll<Option<Self::Item>> {
             match self.partition {
                 0 => {
                     if self.none_polled_once {
                         panic!("Exhausted stream is polled more than once")
                     } else {
                         self.none_polled_once = true;
                         Poll::Ready(None)
                     }
                 }
                 1 => {
-                    let cleared = self.congestion_cleared.lock().unwrap();
-                    if *cleared {
-                        Poll::Ready(None)
-                    } else {
-                        Poll::Pending
-                    }
+                    ready!(self.congestion.check_congested(cx));
+                    Poll::Ready(None)
                 }
                 2 => {
-                    let mut cleared = self.congestion_cleared.lock().unwrap();
-                    *cleared = true;
+                    self.congestion.clear_congestion();
                     Poll::Ready(None)
                 }
                 _ => unreachable!(),
@@ -1423,7 +1450,7 @@
     async fn test_spm_congestion() -> Result<()> {
         let task_ctx = Arc::new(TaskContext::default());
         let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
         let source = CongestedExec {
             schema: schema.clone(),
             cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
-            congestion_cleared: Arc::new(Mutex::new(false)),
+            congestion: Arc::new(Congestion::new()),
         };
         let spm = SortPreservingMergeExec::new(
             [PhysicalSortExpr::new_default(Arc::new(Column::new(

From c3d5ae92660ffce77f67f32fdf19efcad00efec4 Mon Sep 17 00:00:00 2001
From: Pepijn Van Eeckhoudt
Date: Sun, 8 Jun 2025 19:53:03 +0200
Subject: [PATCH 3/3] Make test_spm_congestion independent of the exact poll order

---
 .../src/sorts/sort_preserving_merge.rs        | 62 +++++++++++--------
 1 file changed, 35 insertions(+), 27 deletions(-)

diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
index b5f192c437ec8..2944ac230f38f 100644
--- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
+++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -378,6 +378,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
 #[cfg(test)]
 mod tests {
+    use std::collections::HashSet;
     use std::fmt::Formatter;
     use std::pin::Pin;
     use std::sync::Mutex;
@@ -1285,34 +1286,39 @@ mod tests {
     "#);
     }
 
+    #[derive(Debug)]
+    struct CongestionState {
+        wakers: Vec<Waker>,
+        unpolled_partitions: HashSet<usize>,
+    }
+
     #[derive(Debug)]
     struct Congestion {
-        congestion_cleared: Mutex<Option<Vec<Waker>>>,
+        congestion_state: Mutex<CongestionState>,
     }
 
     impl Congestion {
-        fn new() -> Self {
+        fn new(partition_count: usize) -> Self {
             Congestion {
-                congestion_cleared: Mutex::new(Some(vec![])),
+                congestion_state: Mutex::new(CongestionState {
+                    wakers: vec![],
+                    unpolled_partitions: (0usize..partition_count).collect(),
+                }),
             }
         }
 
-        fn clear_congestion(&self) {
-            let mut cleared = self.congestion_cleared.lock().unwrap();
-            if let Some(wakers) = &mut *cleared {
-                wakers.iter().for_each(|w| w.wake_by_ref());
-                *cleared = None;
-            }
-        }
+        fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
+            let mut state = self.congestion_state.lock().unwrap();
 
-        fn check_congested(&self, cx: &mut Context<'_>) -> Poll<()> {
-            let mut cleared = self.congestion_cleared.lock().unwrap();
-            match &mut *cleared {
-                None => Poll::Ready(()),
-                Some(wakers) => {
-                    wakers.push(cx.waker().clone());
-                    Poll::Pending
-                }
+            state.unpolled_partitions.remove(&partition);
+
+            if state.unpolled_partitions.is_empty() {
+                state.wakers.iter().for_each(|w| w.wake_by_ref());
+                state.wakers.clear();
+                Poll::Ready(())
+            } else {
+                state.wakers.push(cx.waker().clone());
+                Poll::Pending
             }
         }
     }
@@ -1417,6 +1423,7 @@ mod tests {
         ) -> Poll<Option<Self::Item>> {
             match self.partition {
                 0 => {
+                    let _ = self.congestion.check_congested(self.partition, cx);
                     if self.none_polled_once {
                         panic!("Exhausted stream is polled more than once")
                     } else {
                         self.none_polled_once = true;
                         Poll::Ready(None)
                     }
                 }
-                1 => {
-                    ready!(self.congestion.check_congested(cx));
+                _ => {
+                    ready!(self.congestion.check_congested(self.partition, cx));
                     Poll::Ready(None)
                 }
-                2 => {
-                    self.congestion.clear_congestion();
-                    Poll::Ready(None)
-                }
-                _ => unreachable!(),
             }
         }
     }
@@ -1447,10 +1449,16 @@
     async fn test_spm_congestion() -> Result<()> {
         let task_ctx = Arc::new(TaskContext::default());
         let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
+        let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
+        let &partition_count = match properties.output_partitioning() {
+            Partitioning::RoundRobinBatch(partitions) => partitions,
+            Partitioning::Hash(_, partitions) => partitions,
+            Partitioning::UnknownPartitioning(partitions) => partitions,
+        };
         let source = CongestedExec {
             schema: schema.clone(),
-            cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
-            congestion: Arc::new(Congestion::new()),
+            cache: properties,
+            congestion: Arc::new(Congestion::new(partition_count)),
         };
         let spm = SortPreservingMergeExec::new(
             [PhysicalSortExpr::new_default(Arc::new(Column::new(
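
The three patches above all rest on the standard waker contract of Rust's Future and Stream traits: a task that returns Poll::Pending is expected to have stored the Waker from its Context somewhere, and it is polled again only after someone calls wake on that Waker. That is why patch 1 can drop the manual cx.waker().wake_by_ref() self-wake from merge.rs, and why the test's Congestion helper keeps a list of wakers to fire later. The following is a minimal, self-contained sketch of that contract; it is not DataFusion code, the Gate and WaitForGate names are invented for illustration, and the futures crate is assumed only for block_on.

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::thread;
use std::time::Duration;

/// Shared state behind a mutex: whether the gate is open, plus the wakers of
/// every task that found it closed.
#[derive(Default)]
struct GateState {
    open: bool,
    wakers: Vec<Waker>,
}

#[derive(Default)]
struct Gate {
    state: Mutex<GateState>,
}

impl Gate {
    /// Open the gate and wake every task that registered a waker while it was closed.
    fn open(&self) {
        let wakers = {
            let mut state = self.state.lock().unwrap();
            state.open = true;
            std::mem::take(&mut state.wakers)
        };
        for waker in wakers {
            waker.wake();
        }
    }
}

/// A future that completes once the gate is open.
struct WaitForGate(Arc<Gate>);

impl Future for WaitForGate {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut state = self.0.state.lock().unwrap();
        if state.open {
            Poll::Ready(())
        } else {
            // Store the waker and return Pending. The executor will not poll this
            // future again until Gate::open calls wake(), so there is no busy-wait
            // and no need for a cx.waker().wake_by_ref() self-wake.
            state.wakers.push(cx.waker().clone());
            Poll::Pending
        }
    }
}

fn main() {
    let gate = Arc::new(Gate::default());
    let opener = Arc::clone(&gate);
    thread::spawn(move || {
        thread::sleep(Duration::from_millis(50));
        opener.open();
    });
    // Blocks until the background thread opens the gate and wakes us.
    futures::executor::block_on(WaitForGate(gate));
    println!("polled again only after the gate was opened");
}

In this sketch the future is polled exactly twice: once when it registers its waker and returns Pending, and once after open() wakes it, which mirrors how the merge stream is re-polled only when one of its inputs actually becomes ready.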
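
Patch 1 also switches uninitiated_partitions from a VecDeque that was rotated on every Pending poll to a plain Vec that is swept with a manual index and swap_remove, as its comments describe. Below is a standalone sketch of that sweep pattern, using hypothetical names and plain integers in place of partition indices.

/// Remove every "ready" entry from `pending` in place, leaving only pending ones.
fn sweep_ready(pending: &mut Vec<u32>, is_ready: impl Fn(u32) -> bool) {
    // Manual index instead of an iterator because the vector shrinks while we walk it.
    let mut idx = 0;
    while idx < pending.len() {
        if is_ready(pending[idx]) {
            // O(1) removal: the last element moves into slot `idx`, so the same slot
            // is re-checked on the next iteration instead of advancing.
            pending.swap_remove(idx);
        } else {
            // Entry is still pending; leave it in place and move on.
            idx += 1;
        }
    }
}

fn main() {
    let mut pending = vec![0, 1, 2, 3, 4];
    // Treat even "partitions" as ready; odd ones stay pending.
    sweep_ready(&mut pending, |partition| partition % 2 == 0);
    let mut sorted = pending.clone();
    sorted.sort();
    assert_eq!(sorted, vec![1, 3]);
    println!("still pending: {pending:?}");
}

Because swap_remove moves the last element into the vacated slot, the index only advances when nothing was removed; element order is not preserved, which is acceptable here since the caller only cares about the vector eventually becoming empty.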