diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index f1d0ead23ab19..3ce333de234f9 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -34,7 +34,7 @@ use datafusion_common::{ }; /// Holds the state of evaluating a window function -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct WindowAggState { /// The range that we calculate the window function pub window_frame_range: Range<usize>, @@ -90,7 +90,12 @@ impl WindowAggState { partition_batch_state: &PartitionBatchState, ) -> Result<()> { self.last_calculated_index += out_col.len(); - self.out_col = concat(&[&self.out_col, &out_col])?; + // no need to use concat if the current `out_col` is empty + if self.out_col.is_empty() { + self.out_col = Arc::clone(out_col); + } else { + self.out_col = concat(&[&self.out_col, &out_col])?; + } self.n_row_result_missing = partition_batch_state.record_batch.num_rows() - self.last_calculated_index; self.is_end = partition_batch_state.is_end; @@ -112,7 +117,7 @@ impl WindowAggState { } /// This object stores the window frame state for use in incremental calculations. -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum WindowFrameContext { /// ROWS frames are inherently stateless. 
Rows(Arc<WindowFrame>), @@ -244,7 +249,7 @@ impl WindowFrameContext { } /// State for each unique partition determined according to PARTITION BY column(s) -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub struct PartitionBatchState { /// The record batch belonging to current partition pub record_batch: RecordBatch, @@ -269,6 +274,15 @@ impl PartitionBatchState { } } + pub fn new_with_batch(batch: RecordBatch) -> Self { + Self { + record_batch: batch, + most_recent_row: None, + is_end: false, + n_out_row: 0, + } + } + pub fn extend(&mut self, batch: &RecordBatch) -> Result<()> { self.record_batch = concat_batches(&self.record_batch.schema(), [&self.record_batch, batch])?; @@ -286,7 +300,7 @@ impl PartitionBatchState { /// ranges of data while processing RANGE frames. /// Attribute `sort_options` stores the column ordering specified by the ORDER /// BY clause. This information is used to calculate the range. -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct WindowFrameStateRange { sort_options: Vec<SortOptions>, } @@ -458,7 +472,7 @@ impl WindowFrameStateRange { /// This structure encapsulates all the state information we require as we /// scan groups of data while processing window frames. -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct WindowFrameStateGroups { /// A tuple containing group values and the row index where the group ends. /// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] 
would correspond to diff --git a/datafusion/physical-expr/src/window/standard.rs b/datafusion/physical-expr/src/window/standard.rs index 22e8aea83fe78..22193094bde2c 100644 --- a/datafusion/physical-expr/src/window/standard.rs +++ b/datafusion/physical-expr/src/window/standard.rs @@ -158,6 +158,9 @@ impl WindowExpr for StandardWindowExpr { let field = self.expr.field()?; let out_type = field.data_type(); let sort_options = self.order_by.iter().map(|o| o.options).collect::<Vec<_>>(); + // create a WindowAggState to clone when `window_agg_state` does not contain the respective + // group, which is faster than potentially creating a new one at every iteration + let new_state = WindowAggState::new(out_type)?; for (partition_row, partition_batch_state) in partition_batches.iter() { let window_state = if let Some(window_state) = window_agg_state.get_mut(partition_row) { @@ -167,7 +170,7 @@ impl WindowExpr for StandardWindowExpr { window_agg_state .entry(partition_row.clone()) .or_insert(WindowState { - state: WindowAggState::new(out_type)?, + state: new_state.clone(), window_fn: WindowFn::Builtin(evaluator), }) }; @@ -232,6 +235,9 @@ impl WindowExpr for StandardWindowExpr { } let out_col = if row_wise_results.is_empty() { new_empty_array(out_type) + } else if row_wise_results.len() == 1 { + // fast path when the result only has a single row + row_wise_results[0].to_array()? } else { ScalarValue::iter_to_array(row_wise_results.into_iter())? 
}; diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 92138bf6a7a1a..9cadcb819351c 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -414,16 +414,25 @@ trait PartitionSearcher: Send { let partition_batches = self.evaluate_partition_batches(&record_batch, window_expr)?; for (partition_row, partition_batch) in partition_batches { - let partition_batch_state = partition_buffers - .entry(partition_row) + if let Some(partition_batch_state) = partition_buffers.get_mut(&partition_row) + { + partition_batch_state.extend(&partition_batch)? + } else { + let options = RecordBatchOptions::new() + .with_row_count(Some(partition_batch.num_rows())); // Use input_schema for the buffer schema, not `record_batch.schema()` // as it may not have the "correct" schema in terms of output // nullability constraints. 
For details, see the following issue: // https://github.com/apache/datafusion/issues/9320 - .or_insert_with(|| { - PartitionBatchState::new(Arc::clone(self.input_schema())) - }); - partition_batch_state.extend(&partition_batch)?; + let partition_batch = RecordBatch::try_new_with_options( + Arc::clone(self.input_schema()), + partition_batch.columns().to_vec(), + &options, + )?; + let partition_batch_state = + PartitionBatchState::new_with_batch(partition_batch); + partition_buffers.insert(partition_row, partition_batch_state); + } } if self.is_mode_linear() { @@ -855,9 +864,11 @@ impl SortedSearch { cur_window_expr_out_result_len }); argmin(out_col_counts).map_or(0, |(min_idx, minima)| { - for (row, count) in counts.swap_remove(min_idx).into_iter() { - let partition_batch = &mut partition_buffers[row]; - partition_batch.n_out_row = count; + let mut slowest_partition = counts.swap_remove(min_idx); + for (partition_key, partition_batch) in partition_buffers.iter_mut() { + if let Some(count) = slowest_partition.remove(partition_key) { + partition_batch.n_out_row = count; + } } minima }) @@ -1161,6 +1172,7 @@ fn get_aggregate_result_out_column( ) -> Result<ArrayRef> { let mut result = None; let mut running_length = 0; + let mut batches_to_concat = vec![]; // We assume that iteration order is according to insertion order for ( _, @@ -1172,16 +1184,25 @@ { if running_length < len_to_show { let n_to_use = min(len_to_show - running_length, out_col.len()); - let slice_to_use = out_col.slice(0, n_to_use); - result = Some(match result { - Some(arr) => concat(&[&arr, &slice_to_use])?, - None => slice_to_use, - }); + let slice_to_use = if n_to_use == out_col.len() { + // avoid slice when the entire column is used + Arc::clone(out_col) + } else { + out_col.slice(0, n_to_use) + }; + batches_to_concat.push(slice_to_use); running_length += n_to_use; } else { break; } } + + if !batches_to_concat.is_empty() { + let array_refs: Vec<&dyn Array> = + 
batches_to_concat.iter().map(|a| a.as_ref()).collect(); + result = Some(concat(&array_refs)?); + } + if running_length != len_to_show { return exec_err!( "Generated row number should be {len_to_show}, it is {running_length}"