From 3c304e05194c4e14e020e27a0d00578ed3a749fb Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 21 Jun 2023 02:56:05 -0400 Subject: [PATCH] Move `PartitionEvaluator` and window_state structures to `datafusion_expr` crate (#6690) * Move `PartitonEvaluator` and window_state structures to `datafusion_expr` crate * Update docs --- .../windows/bounded_window_agg_exec.rs | 4 +- datafusion/expr/src/lib.rs | 3 + .../src}/partition_evaluator.rs | 11 +- .../src/window_state.rs} | 131 +++++++++++++++--- .../physical-expr/src/window/built_in.rs | 7 +- .../window/built_in_window_function_expr.rs | 2 +- .../physical-expr/src/window/cume_dist.rs | 2 +- .../physical-expr/src/window/lead_lag.rs | 5 +- datafusion/physical-expr/src/window/mod.rs | 4 - .../physical-expr/src/window/nth_value.rs | 5 +- datafusion/physical-expr/src/window/ntile.rs | 2 +- datafusion/physical-expr/src/window/rank.rs | 5 +- .../physical-expr/src/window/row_number.rs | 2 +- .../physical-expr/src/window/window_expr.rs | 102 +------------- 14 files changed, 142 insertions(+), 143 deletions(-) rename datafusion/{physical-expr/src/window => expr/src}/partition_evaluator.rs (96%) rename datafusion/{physical-expr/src/window/window_frame_state.rs => expr/src/window_state.rs} (85%) diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 2512776e8dd4..9c86abec8165 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -41,6 +41,7 @@ use arrow::{ datatypes::{Schema, SchemaBuilder, SchemaRef}, record_batch::RecordBatch, }; +use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; use futures::stream::Stream; use futures::{ready, StreamExt}; use hashbrown::raw::RawTable; @@ -62,8 +63,7 @@ use datafusion_common::DataFusionError; use datafusion_expr::ColumnarValue; use datafusion_physical_expr::hash_utils::create_hashes; use datafusion_physical_expr::window::{ - PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates, - WindowAggState, WindowState, + PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState, }; use datafusion_physical_expr::{ EquivalenceProperties, OrderingEquivalenceProperties, PhysicalExpr, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 1675afb9c98a..ccb972887778 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -41,6 +41,7 @@ mod literal; pub mod logical_plan; mod nullif; mod operator; +mod partition_evaluator; mod signature; pub mod struct_expressions; mod table_source; @@ -51,6 +52,7 @@ mod udf; pub mod utils; pub mod window_frame; pub mod window_function; +pub mod window_state; pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; @@ -69,6 +71,7 @@ pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use logical_plan::*; pub use nullif::SUPPORTED_NULLIF_TYPES; pub use operator::Operator; +pub use partition_evaluator::PartitionEvaluator; pub use signature::{Signature, TypeSignature, Volatility}; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs similarity index 96% rename from datafusion/physical-expr/src/window/partition_evaluator.rs rename to datafusion/expr/src/partition_evaluator.rs index e518e89a75d0..6b159d71059e 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/expr/src/partition_evaluator.rs @@ -17,20 +17,21 @@ //! Partition evaluation module -use crate::window::WindowAggState; use arrow::array::ArrayRef; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use std::fmt::Debug; use std::ops::Range; +use crate::window_state::WindowAggState; + /// Partition evaluator for Window Functions /// /// # Background /// /// An implementation of this trait is created and used for each /// partition defined by an `OVER` clause and is instantiated by -/// [`BuiltInWindowFunctionExpr::create_evaluator`] +/// the DataFusion runtime. /// /// For example, evaluating `window_func(val) OVER (PARTITION BY col)` /// on the following data: @@ -65,7 +66,8 @@ use std::ops::Range; /// ``` /// /// Different methods on this trait will be called depending on the -/// capabilities described by [`BuiltInWindowFunctionExpr`]: +/// capabilities described by [`Self::supports_bounded_execution`], +/// [`Self::uses_window_frame`], and [`Self::include_rank`], /// /// # Stateless `PartitionEvaluator` /// @@ -95,9 +97,6 @@ use std::ops::Range; /// |false|true|`evaluate` (optionally can also implement `evaluate_all` for more optimized implementation. However, there will be default implementation that is suboptimal) . If we were to implement `ROW_NUMBER` it will end up in this quadrant. Example `OddRowNumber` showcases this use case| /// |true|false|`evaluate` (I think as long as `uses_window_frame` is `true`. There is no way for `supports_bounded_execution` to be false). I couldn't come up with any example for this quadrant | /// |true|true|`evaluate`. If we were to implement `FIRST_VALUE`, it would end up in this quadrant|. -/// -/// [`BuiltInWindowFunctionExpr`]: crate::window::BuiltInWindowFunctionExpr -/// [`BuiltInWindowFunctionExpr::create_evaluator`]: crate::window::BuiltInWindowFunctionExpr::create_evaluator pub trait PartitionEvaluator: Debug + Send { /// Updates the internal state for window function /// diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/expr/src/window_state.rs similarity index 85% rename from datafusion/physical-expr/src/window/window_frame_state.rs rename to datafusion/expr/src/window_state.rs index e23a58a09b66..09ed83a5a3a3 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -15,19 +15,100 @@ // specific language governing permissions and limitations // under the License. -//! This module provides utilities for window frame index calculations -//! depending on the window frame mode: RANGE, ROWS, GROUPS. - -use arrow::array::ArrayRef; -use arrow::compute::kernels::sort::SortOptions; -use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -use std::cmp::min; -use std::collections::VecDeque; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::Arc; +//! Structures used to hold window function state (for implementing WindowUDFs) + +use std::{collections::VecDeque, ops::Range, sync::Arc}; + +use arrow::{ + array::ArrayRef, + compute::{concat, SortOptions}, + datatypes::DataType, + record_batch::RecordBatch, +}; +use datafusion_common::{ + utils::{compare_rows, get_row_at_idx, search_in_slice}, + DataFusionError, Result, ScalarValue, +}; + +use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits}; + +/// Holds the state of evaluating a window function +#[derive(Debug)] +pub struct WindowAggState { + /// The range that we calculate the window function + pub window_frame_range: Range, + pub window_frame_ctx: Option, + /// The index of the last row that its result is calculated inside the partition record batch buffer. + pub last_calculated_index: usize, + /// The offset of the deleted row number + pub offset_pruned_rows: usize, + /// Stores the results calculated by window frame + pub out_col: ArrayRef, + /// Keeps track of how many rows should be generated to be in sync with input record_batch. + // (For each row in the input record batch we need to generate a window result). + pub n_row_result_missing: usize, + /// flag indicating whether we have received all data for this partition + pub is_end: bool, +} + +impl WindowAggState { + pub fn prune_state(&mut self, n_prune: usize) { + self.window_frame_range = Range { + start: self.window_frame_range.start - n_prune, + end: self.window_frame_range.end - n_prune, + }; + self.last_calculated_index -= n_prune; + self.offset_pruned_rows += n_prune; + + match self.window_frame_ctx.as_mut() { + // Rows have no state do nothing + Some(WindowFrameContext::Rows(_)) => {} + Some(WindowFrameContext::Range { .. }) => {} + Some(WindowFrameContext::Groups { state, .. }) => { + let mut n_group_to_del = 0; + for (_, end_idx) in &state.group_end_indices { + if n_prune < *end_idx { + break; + } + n_group_to_del += 1; + } + state.group_end_indices.drain(0..n_group_to_del); + state + .group_end_indices + .iter_mut() + .for_each(|(_, start_idx)| *start_idx -= n_prune); + state.current_group_idx -= n_group_to_del; + } + None => {} + }; + } + + pub fn update( + &mut self, + out_col: &ArrayRef, + partition_batch_state: &PartitionBatchState, + ) -> Result<()> { + self.last_calculated_index += out_col.len(); + self.out_col = concat(&[&self.out_col, &out_col])?; + self.n_row_result_missing = + partition_batch_state.record_batch.num_rows() - self.last_calculated_index; + self.is_end = partition_batch_state.is_end; + Ok(()) + } + + pub fn new(out_type: &DataType) -> Result { + let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); + Ok(Self { + window_frame_range: Range { start: 0, end: 0 }, + window_frame_ctx: None, + last_calculated_index: 0, + offset_pruned_rows: 0, + out_col: empty_out_col, + n_row_result_missing: 0, + is_end: false, + }) + } +} /// This object stores the window frame state for use in incremental calculations. #[derive(Debug)] @@ -125,7 +206,7 @@ impl WindowFrameContext { ))) } WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { - min(idx + n as usize, length) + std::cmp::min(idx + n as usize, length) } // ERRONEOUS FRAMES WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { @@ -150,7 +231,7 @@ impl WindowFrameContext { // UNBOUNDED FOLLOWING WindowFrameBound::Following(ScalarValue::UInt64(None)) => length, WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { - min(idx + n as usize + 1, length) + std::cmp::min(idx + n as usize + 1, length) } // ERRONEOUS FRAMES WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { @@ -161,6 +242,17 @@ impl WindowFrameContext { } } +/// State for each unique partition determined according to PARTITION BY column(s) +#[derive(Debug)] +pub struct PartitionBatchState { + /// The record_batch belonging to current partition + pub record_batch: RecordBatch, + /// Flag indicating whether we have received all data for this partition + pub is_end: bool, + /// Number of rows emitted for each partition + pub n_out_row: usize, +} + /// This structure encapsulates all the state information we require as we scan /// ranges of data while processing RANGE frames. /// Attribute `sort_options` stores the column ordering specified by the ORDER @@ -510,7 +602,7 @@ impl WindowFrameStateGroups { Ok(match (SIDE, SEARCH_SIDE) { // Window frame start: (true, _) => { - let group_idx = min(group_idx, self.group_end_indices.len()); + let group_idx = std::cmp::min(group_idx, self.group_end_indices.len()); if group_idx > 0 { // Normally, start at the boundary of the previous group. self.group_end_indices[group_idx - 1].1 @@ -531,7 +623,7 @@ impl WindowFrameStateGroups { } // Window frame end, FOLLOWING n (false, false) => { - let group_idx = min( + let group_idx = std::cmp::min( self.current_group_idx + delta, self.group_end_indices.len() - 1, ); @@ -547,11 +639,10 @@ fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result, - pub window_frame_ctx: Option, - /// The index of the last row that its result is calculated inside the partition record batch buffer. - pub last_calculated_index: usize, - /// The offset of the deleted row number - pub offset_pruned_rows: usize, - /// Stores the results calculated by window frame - pub out_col: ArrayRef, - /// Keeps track of how many rows should be generated to be in sync with input record_batch. - // (For each row in the input record batch we need to generate a window result). - pub n_row_result_missing: usize, - /// flag indicating whether we have received all data for this partition - pub is_end: bool, -} - -impl WindowAggState { - pub fn prune_state(&mut self, n_prune: usize) { - self.window_frame_range = Range { - start: self.window_frame_range.start - n_prune, - end: self.window_frame_range.end - n_prune, - }; - self.last_calculated_index -= n_prune; - self.offset_pruned_rows += n_prune; - - match self.window_frame_ctx.as_mut() { - // Rows have no state do nothing - Some(WindowFrameContext::Rows(_)) => {} - Some(WindowFrameContext::Range { .. }) => {} - Some(WindowFrameContext::Groups { state, .. }) => { - let mut n_group_to_del = 0; - for (_, end_idx) in &state.group_end_indices { - if n_prune < *end_idx { - break; - } - n_group_to_del += 1; - } - state.group_end_indices.drain(0..n_group_to_del); - state - .group_end_indices - .iter_mut() - .for_each(|(_, start_idx)| *start_idx -= n_prune); - state.current_group_idx -= n_group_to_del; - } - None => {} - }; - } -} - -impl WindowAggState { - pub fn update( - &mut self, - out_col: &ArrayRef, - partition_batch_state: &PartitionBatchState, - ) -> Result<()> { - self.last_calculated_index += out_col.len(); - self.out_col = concat(&[&self.out_col, &out_col])?; - self.n_row_result_missing = - partition_batch_state.record_batch.num_rows() - self.last_calculated_index; - self.is_end = partition_batch_state.is_end; - Ok(()) - } -} - -/// State for each unique partition determined according to PARTITION BY column(s) -#[derive(Debug)] -pub struct PartitionBatchState { - /// The record_batch belonging to current partition - pub record_batch: RecordBatch, - /// Flag indicating whether we have received all data for this partition - pub is_end: bool, - /// Number of rows emitted for each partition - pub n_out_row: usize, -} - /// Key for IndexMap for each unique partition /// /// For instance, if window frame is `OVER(PARTITION BY a,b)`, @@ -420,18 +343,3 @@ pub type PartitionWindowAggStates = IndexMap; /// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition. pub type PartitionBatches = IndexMap; - -impl WindowAggState { - pub fn new(out_type: &DataType) -> Result { - let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); - Ok(Self { - window_frame_range: Range { start: 0, end: 0 }, - window_frame_ctx: None, - last_calculated_index: 0, - offset_pruned_rows: 0, - out_col: empty_out_col, - n_row_result_missing: 0, - is_end: false, - }) - } -}