Skip to content

Commit

Permalink
Move window_frame_state and partition_evaluator to datafusion_expr
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jun 9, 2023
1 parent 53064e1 commit edf0afc
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 92 deletions.
7 changes: 7 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use crate::function_err::generate_signature_error_msg;
use crate::nullif::SUPPORTED_NULLIF_TYPES;
use crate::partition_evaluator::PartitionEvaluator;
use crate::type_coercion::functions::data_types;
use crate::ColumnarValue;
use crate::{
Expand Down Expand Up @@ -54,6 +55,12 @@ pub type AccumulatorFunctionImplementation =
pub type StateTypeFunction =
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;

/// Factory that creates a boxed [`PartitionEvaluator`].
///
/// The closure is handed a reference to a [`DataType`] (per the original
/// note, the function's return datatype) and yields the evaluator, or an
/// error if one cannot be created. The factory itself is shared through an
/// `Arc` and must be `Send + Sync`.
///
/// NOTE(review): the previous wording said "for the given aggregate"; this
/// presumably refers to a window function rather than an aggregate — confirm.
pub type PartitionEvaluatorFunctionFactory =
Arc<dyn Fn(&DataType) -> Result<Box<dyn PartitionEvaluator>> + Send + Sync>;


macro_rules! make_utf8_to_return_type {
($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ mod udwf;
pub mod utils;
pub mod window_frame;
pub mod window_function;
pub mod partition_evaluator;
pub mod window_frame_state;

pub use accumulator::Accumulator;
pub use aggregate_function::AggregateFunction;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,25 @@

//! Partition evaluation module

use crate::window::window_expr::BuiltinWindowState;
use crate::window::WindowAggState;
use crate::window_frame_state::WindowAggState;
use arrow::array::ArrayRef;
use datafusion_common::Result;
use datafusion_common::{DataFusionError, ScalarValue};
use std::any::Any;
use std::fmt::Debug;
use std::ops::Range;


/// Trait for the state managed by a partition evaluator.
///
/// Implementors expose themselves as [`Any`](std::any::Any) so that the
/// concrete state type can be recovered by downcasting.
///
/// NOTE(review): this follows the pre-existing `as_any` pattern; the original
/// author flagged it as something that could be improved.

pub trait PartitionState {
/// Returns this state as [`Any`](std::any::Any) so that it can be
/// downcast to a specific implementation.
fn as_any(&self) -> &dyn Any;
}

/// Partition evaluator for Window Functions
///
/// # Background
Expand Down Expand Up @@ -100,12 +111,9 @@ pub trait PartitionEvaluator: Debug + Send {
false
}

/// Returns the internal state of the window function
///
/// Only used for stateful evaluation
fn state(&self) -> Result<BuiltinWindowState> {
// If we do not use state we just return Default
Ok(BuiltinWindowState::Default)
/// Returns the internal state of the window function, if any
fn state(&self) -> Result<Option<Box<dyn PartitionState>>> {
Ok(None)
}

/// Updates the internal state for window function
Expand All @@ -130,7 +138,7 @@ pub trait PartitionEvaluator: Debug + Send {
/// Sets the internal state for window function
///
/// Only used for stateful evaluation
fn set_state(&mut self, _state: &BuiltinWindowState) -> Result<()> {
fn set_state(&mut self, state: Box<dyn PartitionState>) -> Result<()> {
Err(DataFusionError::NotImplemented(
"set_state is not implemented for this window function".to_string(),
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,97 @@
//! depending on the window frame mode: RANGE, ROWS, GROUPS.

use arrow::array::ArrayRef;
use arrow::compute::{concat};
use arrow::compute::kernels::sort::SortOptions;
use arrow::record_batch::RecordBatch;
use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use std::cmp::min;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::ops::Range;
use std::sync::Arc;


/// Buffered input state kept for each unique partition, where partitions are
/// determined by the PARTITION BY column(s).
#[derive(Debug)]
pub struct PartitionBatchState {
/// The record batch belonging to the current partition
pub record_batch: RecordBatch,
/// Flag indicating whether all data for this partition has been received
pub is_end: bool,
/// Number of rows emitted for this partition
pub n_out_row: usize,
}


/// Bookkeeping for the incremental evaluation of a window function.
///
/// The fields are advanced by `update` (which appends newly calculated
/// results) and rebased by `prune_state` (which accounts for rows removed
/// from the front of the buffer).
#[derive(Debug)]
pub struct WindowAggState {
/// The range of rows over which the window function is calculated
pub window_frame_range: Range<usize>,
/// Window frame context, if any. `prune_state` only has work to do for the
/// `Groups` variant; the other variants and `None` are left untouched there.
pub window_frame_ctx: Option<WindowFrameContext>,
/// Index, inside the partition record batch buffer, of the last row whose
/// result has been calculated.
pub last_calculated_index: usize,
/// Running total of rows removed by `prune_state`
pub offset_pruned_rows: usize,
/// Stores the results calculated by the window frame
pub out_col: ArrayRef,
/// Keeps track of how many rows still have to be generated to be in sync
/// with the input record batch (for each row in the input record batch a
/// window result must be generated).
pub n_row_result_missing: usize,
/// Flag indicating whether all data for this partition has been received
pub is_end: bool,
}

impl WindowAggState {
pub fn prune_state(&mut self, n_prune: usize) {
self.window_frame_range = Range {
start: self.window_frame_range.start - n_prune,
end: self.window_frame_range.end - n_prune,
};
self.last_calculated_index -= n_prune;
self.offset_pruned_rows += n_prune;

match self.window_frame_ctx.as_mut() {
// Rows have no state do nothing
Some(WindowFrameContext::Rows(_)) => {}
Some(WindowFrameContext::Range { .. }) => {}
Some(WindowFrameContext::Groups { state, .. }) => {
let mut n_group_to_del = 0;
for (_, end_idx) in &state.group_end_indices {
if n_prune < *end_idx {
break;
}
n_group_to_del += 1;
}
state.group_end_indices.drain(0..n_group_to_del);
state
.group_end_indices
.iter_mut()
.for_each(|(_, start_idx)| *start_idx -= n_prune);
state.current_group_idx -= n_group_to_del;
}
None => {}
};
}
}

impl WindowAggState {
    /// Folds a freshly calculated chunk of results into this state.
    ///
    /// `last_calculated_index` advances by the length of `out_col`, the new
    /// values are appended to the stored `out_col`, `n_row_result_missing` is
    /// recomputed as the number of rows in the partition's record batch minus
    /// the updated `last_calculated_index`, and `is_end` is copied from
    /// `partition_batch_state`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `concat`. The index has already been
    /// advanced when that happens.
    pub fn update(
        &mut self,
        out_col: &ArrayRef,
        partition_batch_state: &PartitionBatchState,
    ) -> Result<()> {
        self.last_calculated_index += out_col.len();
        let merged = concat(&[self.out_col.as_ref(), out_col.as_ref()])?;
        self.out_col = merged;
        let buffered_rows = partition_batch_state.record_batch.num_rows();
        self.n_row_result_missing = buffered_rows - self.last_calculated_index;
        self.is_end = partition_batch_state.is_end;
        Ok(())
    }
}

/// This object stores the window frame state for use in incremental calculations.
#[derive(Debug)]
pub enum WindowFrameContext {
Expand Down Expand Up @@ -547,11 +628,10 @@ fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result<boo

#[cfg(test)]
mod tests {
use crate::window::window_frame_state::WindowFrameStateGroups;
use super::*;
use arrow::array::{ArrayRef, Float64Array};
use arrow_schema::SortOptions;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use std::ops::Range;
use std::sync::Arc;

Expand Down
2 changes: 0 additions & 2 deletions datafusion/physical-expr/src/window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ pub(crate) mod cume_dist;
pub(crate) mod lead_lag;
pub(crate) mod nth_value;
pub(crate) mod ntile;
pub(crate) mod partition_evaluator;
pub(crate) mod rank;
pub(crate) mod row_number;
mod sliding_aggregate;
mod window_expr;
mod window_frame_state;

pub use aggregate::PlainAggregateWindowExpr;
pub use built_in::BuiltInWindowExpr;
Expand Down
77 changes: 0 additions & 77 deletions datafusion/physical-expr/src/window/window_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,83 +337,6 @@ pub enum BuiltinWindowState {
Default,
}

/// Bookkeeping for the incremental evaluation of a window function.
///
/// The fields are advanced by `update` (which appends newly calculated
/// results) and rebased by `prune_state` (which accounts for rows removed
/// from the front of the buffer).
#[derive(Debug)]
pub struct WindowAggState {
/// The range of rows over which the window function is calculated
pub window_frame_range: Range<usize>,
/// Window frame context, if any. `prune_state` only has work to do for the
/// `Groups` variant; the other variants and `None` are left untouched there.
pub window_frame_ctx: Option<WindowFrameContext>,
/// Index, inside the partition record batch buffer, of the last row whose
/// result has been calculated.
pub last_calculated_index: usize,
/// Running total of rows removed by `prune_state`
pub offset_pruned_rows: usize,
/// Stores the results calculated by the window frame
pub out_col: ArrayRef,
/// Keeps track of how many rows still have to be generated to be in sync
/// with the input record batch (for each row in the input record batch a
/// window result must be generated).
pub n_row_result_missing: usize,
/// Flag indicating whether all data for this partition has been received
pub is_end: bool,
}

impl WindowAggState {
pub fn prune_state(&mut self, n_prune: usize) {
self.window_frame_range = Range {
start: self.window_frame_range.start - n_prune,
end: self.window_frame_range.end - n_prune,
};
self.last_calculated_index -= n_prune;
self.offset_pruned_rows += n_prune;

match self.window_frame_ctx.as_mut() {
// Rows have no state do nothing
Some(WindowFrameContext::Rows(_)) => {}
Some(WindowFrameContext::Range { .. }) => {}
Some(WindowFrameContext::Groups { state, .. }) => {
let mut n_group_to_del = 0;
for (_, end_idx) in &state.group_end_indices {
if n_prune < *end_idx {
break;
}
n_group_to_del += 1;
}
state.group_end_indices.drain(0..n_group_to_del);
state
.group_end_indices
.iter_mut()
.for_each(|(_, start_idx)| *start_idx -= n_prune);
state.current_group_idx -= n_group_to_del;
}
None => {}
};
}
}

impl WindowAggState {
    /// Folds a freshly calculated chunk of results into this state.
    ///
    /// `last_calculated_index` advances by the length of `out_col`, the new
    /// values are appended to the stored `out_col`, `n_row_result_missing` is
    /// recomputed as the number of rows in the partition's record batch minus
    /// the updated `last_calculated_index`, and `is_end` is copied from
    /// `partition_batch_state`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `concat`. The index has already been
    /// advanced when that happens.
    pub fn update(
        &mut self,
        out_col: &ArrayRef,
        partition_batch_state: &PartitionBatchState,
    ) -> Result<()> {
        self.last_calculated_index += out_col.len();
        let merged = concat(&[self.out_col.as_ref(), out_col.as_ref()])?;
        self.out_col = merged;
        let buffered_rows = partition_batch_state.record_batch.num_rows();
        self.n_row_result_missing = buffered_rows - self.last_calculated_index;
        self.is_end = partition_batch_state.is_end;
        Ok(())
    }
}

/// Buffered input state kept for each unique partition, where partitions are
/// determined by the PARTITION BY column(s).
#[derive(Debug)]
pub struct PartitionBatchState {
/// The record batch belonging to the current partition
pub record_batch: RecordBatch,
/// Flag indicating whether all data for this partition has been received
pub is_end: bool,
/// Number of rows emitted for this partition
pub n_out_row: usize,
}

/// Key for IndexMap for each unique partition
///
/// For instance, if window frame is `OVER(PARTITION BY a,b)`,
Expand Down

0 comments on commit edf0afc

Please sign in to comment.