Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions datafusion/expr/src/window_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use datafusion_common::{
};

/// Holds the state of evaluating a window function
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct WindowAggState {
/// The range that we calculate the window function
pub window_frame_range: Range<usize>,
Expand Down Expand Up @@ -90,7 +90,12 @@ impl WindowAggState {
partition_batch_state: &PartitionBatchState,
) -> Result<()> {
self.last_calculated_index += out_col.len();
self.out_col = concat(&[&self.out_col, &out_col])?;
// no need to use concat if the current `out_col` is empty
if self.out_col.is_empty() {
self.out_col = Arc::clone(out_col);
} else {
self.out_col = concat(&[&self.out_col, &out_col])?;
}
self.n_row_result_missing =
partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
self.is_end = partition_batch_state.is_end;
Expand All @@ -112,7 +117,7 @@ impl WindowAggState {
}

/// This object stores the window frame state for use in incremental calculations.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum WindowFrameContext {
/// ROWS frames are inherently stateless.
Rows(Arc<WindowFrame>),
Expand Down Expand Up @@ -244,7 +249,7 @@ impl WindowFrameContext {
}

/// State for each unique partition determined according to PARTITION BY column(s)
#[derive(Debug)]
#[derive(Debug, Clone, PartialEq)]
pub struct PartitionBatchState {
/// The record batch belonging to current partition
pub record_batch: RecordBatch,
Expand All @@ -269,6 +274,15 @@ impl PartitionBatchState {
}
}

pub fn new_with_batch(batch: RecordBatch) -> Self {
Self {
record_batch: batch,
most_recent_row: None,
is_end: false,
n_out_row: 0,
}
}

pub fn extend(&mut self, batch: &RecordBatch) -> Result<()> {
self.record_batch =
concat_batches(&self.record_batch.schema(), [&self.record_batch, batch])?;
Expand All @@ -286,7 +300,7 @@ impl PartitionBatchState {
/// ranges of data while processing RANGE frames.
/// Attribute `sort_options` stores the column ordering specified by the ORDER
/// BY clause. This information is used to calculate the range.
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub struct WindowFrameStateRange {
sort_options: Vec<SortOptions>,
}
Expand Down Expand Up @@ -458,7 +472,7 @@ impl WindowFrameStateRange {

/// This structure encapsulates all the state information we require as we
/// scan groups of data while processing window frames.
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub struct WindowFrameStateGroups {
/// A tuple containing group values and the row index where the group ends.
/// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] would correspond to
Expand Down
8 changes: 7 additions & 1 deletion datafusion/physical-expr/src/window/standard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ impl WindowExpr for StandardWindowExpr {
let field = self.expr.field()?;
let out_type = field.data_type();
let sort_options = self.order_by.iter().map(|o| o.options).collect::<Vec<_>>();
// create a WindowAggState to clone when `window_agg_state` does not contain the respective
// group, which is faster than potentially creating a new one at every iteration
let new_state = WindowAggState::new(out_type)?;
for (partition_row, partition_batch_state) in partition_batches.iter() {
let window_state =
if let Some(window_state) = window_agg_state.get_mut(partition_row) {
Expand All @@ -167,7 +170,7 @@ impl WindowExpr for StandardWindowExpr {
window_agg_state
.entry(partition_row.clone())
.or_insert(WindowState {
state: WindowAggState::new(out_type)?,
state: new_state.clone(),
window_fn: WindowFn::Builtin(evaluator),
})
};
Expand Down Expand Up @@ -232,6 +235,9 @@ impl WindowExpr for StandardWindowExpr {
}
let out_col = if row_wise_results.is_empty() {
new_empty_array(out_type)
} else if row_wise_results.len() == 1 {
// fast path when the result only has a single row
row_wise_results[0].to_array()?
} else {
ScalarValue::iter_to_array(row_wise_results.into_iter())?
};
Expand Down
49 changes: 35 additions & 14 deletions datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,16 +414,25 @@ trait PartitionSearcher: Send {
let partition_batches =
self.evaluate_partition_batches(&record_batch, window_expr)?;
for (partition_row, partition_batch) in partition_batches {
let partition_batch_state = partition_buffers
.entry(partition_row)
if let Some(partition_batch_state) = partition_buffers.get_mut(&partition_row)
{
partition_batch_state.extend(&partition_batch)?
} else {
let options = RecordBatchOptions::new()
.with_row_count(Some(partition_batch.num_rows()));
// Use input_schema for the buffer schema, not `record_batch.schema()`
// as it may not have the "correct" schema in terms of output
// nullability constraints. For details, see the following issue:
// https://github.com/apache/datafusion/issues/9320
.or_insert_with(|| {
PartitionBatchState::new(Arc::clone(self.input_schema()))
});
partition_batch_state.extend(&partition_batch)?;
let partition_batch = RecordBatch::try_new_with_options(
Arc::clone(self.input_schema()),
partition_batch.columns().to_vec(),
&options,
)?;
let partition_batch_state =
PartitionBatchState::new_with_batch(partition_batch);
partition_buffers.insert(partition_row, partition_batch_state);
}
}

if self.is_mode_linear() {
Expand Down Expand Up @@ -855,9 +864,11 @@ impl SortedSearch {
cur_window_expr_out_result_len
});
argmin(out_col_counts).map_or(0, |(min_idx, minima)| {
for (row, count) in counts.swap_remove(min_idx).into_iter() {
let partition_batch = &mut partition_buffers[row];
partition_batch.n_out_row = count;
let mut slowest_partition = counts.swap_remove(min_idx);
for (partition_key, partition_batch) in partition_buffers.iter_mut() {
if let Some(count) = slowest_partition.remove(partition_key) {
partition_batch.n_out_row = count;
}
}
minima
})
Expand Down Expand Up @@ -1161,6 +1172,7 @@ fn get_aggregate_result_out_column(
) -> Result<ArrayRef> {
let mut result = None;
let mut running_length = 0;
let mut batches_to_concat = vec![];
// We assume that iteration order is according to insertion order
for (
_,
Expand All @@ -1172,16 +1184,25 @@ fn get_aggregate_result_out_column(
{
if running_length < len_to_show {
let n_to_use = min(len_to_show - running_length, out_col.len());
let slice_to_use = out_col.slice(0, n_to_use);
result = Some(match result {
Some(arr) => concat(&[&arr, &slice_to_use])?,
None => slice_to_use,
});
let slice_to_use = if n_to_use == out_col.len() {
// avoid slice when the entire column is used
Arc::clone(out_col)
} else {
out_col.slice(0, n_to_use)
};
batches_to_concat.push(slice_to_use);
running_length += n_to_use;
} else {
break;
}
}

if !batches_to_concat.is_empty() {
let array_refs: Vec<&dyn Array> =
batches_to_concat.iter().map(|a| a.as_ref()).collect();
result = Some(concat(&array_refs)?);
}

if running_length != len_to_show {
return exec_err!(
"Generated row number should be {len_to_show}, it is {running_length}"
Expand Down