From d53d6c407415e532dcad6e27fc664a41a99df496 Mon Sep 17 00:00:00 2001 From: Stuart Carnie Date: Tue, 13 Jun 2023 09:56:54 +1000 Subject: [PATCH 1/6] feat: support sliding window accumulators Rationale: The default implementation of the `Accumulator` trait returns an error for the `retract_batch` API. --- datafusion/core/src/physical_plan/udaf.rs | 4 ++++ datafusion/core/src/physical_plan/windows/mod.rs | 15 ++++++--------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/physical_plan/udaf.rs b/datafusion/core/src/physical_plan/udaf.rs index d9f52eba77d0..73a8f5076572 100644 --- a/datafusion/core/src/physical_plan/udaf.rs +++ b/datafusion/core/src/physical_plan/udaf.rs @@ -106,6 +106,10 @@ impl AggregateExpr for AggregateFunctionExpr { (self.fun.accumulator)(&self.data_type) } + fn create_sliding_accumulator(&self) -> Result> { + (self.fun.accumulator)(&self.data_type) + } + fn name(&self) -> &str { &self.name } diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index a43ada82ee24..8900ab75ee27 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -65,9 +65,12 @@ pub fn create_window_expr( input_schema: &Schema, ) -> Result> { Ok(match fun { - WindowFunction::AggregateFunction(fun) => { - let aggregate = - aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?; + WindowFunction::AggregateFunction(_) | WindowFunction::AggregateUDF(_) => { + let aggregate = match fun { + WindowFunction::AggregateFunction(fun) => aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?, + WindowFunction::AggregateUDF(fun) => udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?, + _ => unreachable!() + }; if !window_frame.start_bound.is_unbounded() { Arc::new(SlidingAggregateWindowExpr::new( aggregate, @@ -90,12 +93,6 @@ pub fn create_window_expr( order_by, window_frame, )), - WindowFunction::AggregateUDF(fun) => Arc::new(PlainAggregateWindowExpr::new( - udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?, - partition_by, - order_by, - window_frame, - )), }) } From 579b4d9915da35a3e477688e258e1e29314f94ef Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 14 Jun 2023 17:13:43 -0400 Subject: [PATCH 2/6] Allow AggregateUDF to define retractable batch --- datafusion/core/src/physical_plan/udaf.rs | 20 ++- .../core/src/physical_plan/windows/mod.rs | 34 ++++- .../core/tests/user_defined_aggregates.rs | 140 ++++++++++++++---- datafusion/expr/src/accumulator.rs | 19 ++- datafusion/expr/src/udaf.rs | 13 +- datafusion/proto/src/physical_plan/mod.rs | 1 + 6 files changed, 183 insertions(+), 44 deletions(-) diff --git a/datafusion/core/src/physical_plan/udaf.rs b/datafusion/core/src/physical_plan/udaf.rs index 73a8f5076572..73c7c8030d54 100644 --- a/datafusion/core/src/physical_plan/udaf.rs +++ b/datafusion/core/src/physical_plan/udaf.rs @@ -28,7 +28,7 @@ use arrow::{ use super::{expressions::format_state_name, Accumulator, AggregateExpr}; use crate::physical_plan::PhysicalExpr; -use datafusion_common::Result; +use datafusion_common::{DataFusionError, Result}; pub use datafusion_expr::AggregateUDF; use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; @@ -41,7 +41,7 @@ pub fn create_aggregate_expr( input_phy_exprs: &[Arc], input_schema: &Schema, name: impl Into, -) -> Result> { +) -> Result> { let input_exprs_types = input_phy_exprs .iter() .map(|arg| arg.data_type(input_schema)) 
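A minimal, illustrative Rust sketch (not part of these patches) of the dispatch this series converges on for user-defined aggregates used as window functions. The names `WindowImpl` and `choose_window_impl` are invented for illustration; in the real code the two inputs correspond to `window_frame.start_bound.is_unbounded()` and the accumulator's `supports_retract_batch()`.

#[derive(Debug, PartialEq)]
enum WindowImpl {
    /// Bounded memory: drives the accumulator with `update_batch` + `retract_batch`.
    Sliding,
    /// Re-evaluates the whole frame; only needs `update_batch`.
    Plain,
}

/// Hypothetical helper mirroring the choice made in `create_window_expr`.
fn choose_window_impl(start_is_unbounded: bool, can_retract: bool) -> WindowImpl {
    if !start_is_unbounded && can_retract {
        WindowImpl::Sliding
    } else {
        WindowImpl::Plain
    }
}

fn main() {
    // `OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)` with a retractable UDAF.
    assert_eq!(choose_window_impl(false, true), WindowImpl::Sliding);
    // Same frame, but the UDAF keeps the default (erroring) `retract_batch`.
    assert_eq!(choose_window_impl(false, false), WindowImpl::Plain);
    // Plain `OVER ()` (unbounded start) always uses the plain implementation.
    assert_eq!(choose_window_impl(true, true), WindowImpl::Plain);
}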
@@ -70,6 +70,11 @@ impl AggregateFunctionExpr { pub fn fun(&self) -> &AggregateUDF { &self.fun } + + /// Returns true if this can support sliding accumulators + pub fn retractable(&self) -> Result { + Ok((self.fun.accumulator)(&self.data_type)?.supports_retract_batch()) + } } impl AggregateExpr for AggregateFunctionExpr { @@ -107,7 +112,16 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_sliding_accumulator(&self) -> Result> { - (self.fun.accumulator)(&self.data_type) + let accumulator = (self.fun.accumulator)(&self.data_type)?; + + if !accumulator.supports_retract_batch() { + return Err(DataFusionError::Internal( + format!( + "Can't make sliding accumulator because retractable_accumulator not available for {}", + self.name) + )); + } + Ok(accumulator) } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 8900ab75ee27..5c2f097ce692 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -64,14 +64,14 @@ pub fn create_window_expr( window_frame: Arc, input_schema: &Schema, ) -> Result> { + // Is there a potentially unlimited sized window frame? + let unbounded_window = window_frame.start_bound.is_unbounded(); + Ok(match fun { - WindowFunction::AggregateFunction(_) | WindowFunction::AggregateUDF(_) => { - let aggregate = match fun { - WindowFunction::AggregateFunction(fun) => aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?, - WindowFunction::AggregateUDF(fun) => udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?, - _ => unreachable!() - }; - if !window_frame.start_bound.is_unbounded() { + WindowFunction::AggregateFunction(fun) => { + let aggregate = + aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?; + if !unbounded_window { Arc::new(SlidingAggregateWindowExpr::new( aggregate, partition_by, @@ -93,6 +93,26 @@ pub fn create_window_expr( order_by, window_frame, )), + WindowFunction::AggregateUDF(fun) => { + let aggregate = + udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; + + if !unbounded_window && aggregate.retractable()? 
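+            // Only take the sliding (bounded-memory) path when the UDAF's
+            // accumulator reports `supports_retract_batch()` (probed via
+            // `retractable()` above); otherwise fall back to
+            // `PlainAggregateWindowExpr`, which never calls `retract_batch`.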
{ + Arc::new(SlidingAggregateWindowExpr::new( + aggregate, + partition_by, + order_by, + window_frame, + )) + } else { + Arc::new(PlainAggregateWindowExpr::new( + aggregate, + partition_by, + order_by, + window_frame, + )) + } + } }) } diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs index 7c95b9a2d49a..633d2cc62f18 100644 --- a/datafusion/core/tests/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined_aggregates.rs @@ -40,13 +40,32 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::cast::as_primitive_array; +use datafusion_common::{cast::as_primitive_array, DataFusionError}; + +/// Test to show the contents of the setup +#[tokio::test] +async fn test_setup() { + let TestContext { ctx, test_state: _ } = TestContext::new(); + let sql = "SELECT * from t order by time"; + let expected = vec![ + "+-------+----------------------------+", + "| value | time |", + "+-------+----------------------------+", + "| 2.0 | 1970-01-01T00:00:00.000002 |", + "| 3.0 | 1970-01-01T00:00:00.000003 |", + "| 1.0 | 1970-01-01T00:00:00.000004 |", + "| 5.0 | 1970-01-01T00:00:00.000005 |", + "| 5.0 | 1970-01-01T00:00:00.000005 |", + "+-------+----------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await); +} /// Basic user defined aggregate #[tokio::test] async fn test_udaf() { - let TestContext { ctx, counters } = TestContext::new(); - assert!(!counters.update_batch()); + let TestContext { ctx, test_state } = TestContext::new(); + assert!(!test_state.update_batch()); let sql = "SELECT time_sum(time) from t"; let expected = vec![ "+----------------------------+", @@ -57,14 +76,14 @@ async fn test_udaf() { ]; assert_batches_eq!(expected, &execute(&ctx, sql).await); // normal aggregates call update_batch - assert!(counters.update_batch()); - assert!(!counters.retract_batch()); + assert!(test_state.update_batch()); + assert!(!test_state.retract_batch()); } /// User defined aggregate used as a window function #[tokio::test] async fn test_udaf_as_window() { - let TestContext { ctx, counters } = TestContext::new(); + let TestContext { ctx, test_state } = TestContext::new(); let sql = "SELECT time_sum(time) OVER() as time_sum from t"; let expected = vec![ "+----------------------------+", @@ -79,15 +98,41 @@ async fn test_udaf_as_window() { ]; assert_batches_eq!(expected, &execute(&ctx, sql).await); // aggregate over the entire window function call update_batch - assert!(counters.update_batch()); - assert!(!counters.retract_batch()); + assert!(test_state.update_batch()); + assert!(!test_state.retract_batch()); } /// User defined aggregate used as a window function with a window frame #[tokio::test] async fn test_udaf_as_window_with_frame() { - let TestContext { ctx, counters } = TestContext::new(); + let TestContext { ctx, test_state } = TestContext::new(); + let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; + let expected = vec![ + "+----------------------------+", + "| time_sum |", + "+----------------------------+", + "| 1970-01-01T00:00:00.000005 |", + "| 1970-01-01T00:00:00.000009 |", + "| 1970-01-01T00:00:00.000012 |", + "| 1970-01-01T00:00:00.000014 |", + "| 1970-01-01T00:00:00.000010 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await); + // user defined aggregates with window frame should be calling retract batch + assert!(test_state.update_batch()); 
+ assert!(test_state.retract_batch()); +} + +/// Ensure that User defined aggregate used as a window function with a window +/// frame, but that does not implement retract_batch, does not error +#[tokio::test] +async fn test_udaf_as_window_with_frame_without_retract_batch() { + let test_state = Arc::new(TestState::new().with_error_on_retract_batch()); + + let TestContext { ctx, test_state } = TestContext::new_with_test_state(test_state); let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; + // TODO: It is not clear why this is a different value than when retract batch is used let expected = vec![ "+----------------------------+", "| time_sum |", @@ -100,16 +145,14 @@ async fn test_udaf_as_window_with_frame() { "+----------------------------+", ]; assert_batches_eq!(expected, &execute(&ctx, sql).await); - // user defined aggregates with window frame should be calling retract batch - // but doesn't yet: https://github.com/apache/arrow-datafusion/issues/6611 - assert!(counters.update_batch()); - assert!(!counters.retract_batch()); + assert!(test_state.update_batch()); + assert!(!test_state.retract_batch()); } /// Basic query for with a udaf returning a structure #[tokio::test] async fn test_udaf_returning_struct() { - let TestContext { ctx, counters: _ } = TestContext::new(); + let TestContext { ctx, test_state: _ } = TestContext::new(); let sql = "SELECT first(value, time) from t"; let expected = vec![ "+------------------------------------------------+", @@ -124,7 +167,7 @@ async fn test_udaf_returning_struct() { /// Demonstrate extracting the fields from a structure using a subquery #[tokio::test] async fn test_udaf_returning_struct_subquery() { - let TestContext { ctx, counters: _ } = TestContext::new(); + let TestContext { ctx, test_state: _ } = TestContext::new(); let sql = "select sq.first['value'], sq.first['time'] from (SELECT first(value, time) as first from t) as sq"; let expected = vec![ "+-----------------+----------------------------+", @@ -155,13 +198,16 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Vec { /// ``` struct TestContext { ctx: SessionContext, - counters: Arc, + test_state: Arc, } impl TestContext { fn new() -> Self { - let counters = Arc::new(TestCounters::new()); + let test_state = Arc::new(TestState::new()); + Self::new_with_test_state(test_state) + } + fn new_with_test_state(test_state: Arc) -> Self { let value = Float64Array::from(vec![3.0, 2.0, 1.0, 5.0, 5.0]); let time = TimestampNanosecondArray::from(vec![3000, 2000, 4000, 5000, 5000]); @@ -178,21 +224,24 @@ impl TestContext { // Tell DataFusion about the "first" function FirstSelector::register(&mut ctx); // Tell DataFusion about the "time_sum" function - TimeSum::register(&mut ctx, Arc::clone(&counters)); + TimeSum::register(&mut ctx, Arc::clone(&test_state)); - Self { ctx, counters } + Self { ctx, test_state } } } #[derive(Debug, Default)] -struct TestCounters { +struct TestState { /// was update_batch called? update_batch: AtomicBool, /// was retract_batch called? retract_batch: AtomicBool, + /// should the udaf throw an error if retract batch is called? Can + /// only be configured at construction time. 
+ error_on_retract_batch: bool, } -impl TestCounters { +impl TestState { fn new() -> Self { Default::default() } @@ -202,10 +251,31 @@ impl TestCounters { self.update_batch.load(Ordering::SeqCst) } + /// Set the `update_batch` flag + fn set_update_batch(&self) { + self.update_batch.store(true, Ordering::SeqCst) + } + /// Has `retract_batch` been called? fn retract_batch(&self) -> bool { self.retract_batch.load(Ordering::SeqCst) } + + /// set the `retract_batch` flag + fn set_retract_batch(&self) { + self.retract_batch.store(true, Ordering::SeqCst) + } + + /// Is this state configured to return an error on retract batch? + fn error_on_retract_batch(&self) -> bool { + self.error_on_retract_batch + } + + /// Configure the test to return error on retract batch + fn with_error_on_retract_batch(mut self) -> Self { + self.error_on_retract_batch = true; + self + } } /// Models a user defined aggregate function that computes the a sum @@ -213,15 +283,15 @@ impl TestCounters { #[derive(Debug)] struct TimeSum { sum: i64, - counters: Arc, + test_state: Arc, } impl TimeSum { - fn new(counters: Arc) -> Self { - Self { sum: 0, counters } + fn new(test_state: Arc) -> Self { + Self { sum: 0, test_state } } - fn register(ctx: &mut SessionContext, counters: Arc) { + fn register(ctx: &mut SessionContext, test_state: Arc) { let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None); // Returns the same type as its input @@ -237,8 +307,9 @@ impl TimeSum { let signature = Signature::exact(vec![timestamp_type], volatility); + let captured_state = Arc::clone(&test_state); let accumulator: AccumulatorFunctionImplementation = - Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&counters))))); + Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); let name = "time_sum"; @@ -256,12 +327,13 @@ impl Accumulator for TimeSum { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.counters.update_batch.store(true, Ordering::SeqCst); + self.test_state.set_update_batch(); assert_eq!(values.len(), 1); let arr = &values[0]; let arr = arr.as_primitive::(); for v in arr.values().iter() { + println!("Adding {v}"); self.sum += v; } Ok(()) @@ -273,6 +345,7 @@ impl Accumulator for TimeSum { } fn evaluate(&self) -> Result { + println!("Evaluating to {}", self.sum); Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None)) } @@ -282,16 +355,27 @@ impl Accumulator for TimeSum { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.counters.retract_batch.store(true, Ordering::SeqCst); + if self.test_state.error_on_retract_batch() { + return Err(DataFusionError::Execution( + "Error in Retract Batch".to_string(), + )); + } + + self.test_state.set_retract_batch(); assert_eq!(values.len(), 1); let arr = &values[0]; let arr = arr.as_primitive::(); for v in arr.values().iter() { + println!("Retracting {v}"); self.sum -= v; } Ok(()) } + + fn supports_retract_batch(&self) -> bool { + !self.test_state.error_on_retract_batch() + } } /// Models a specialized timeseries aggregate function diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs index 7e941d0cff97..c448ed423530 100644 --- a/datafusion/expr/src/accumulator.rs +++ b/datafusion/expr/src/accumulator.rs @@ -21,12 +21,15 @@ use arrow::array::ArrayRef; use datafusion_common::{DataFusionError, Result, ScalarValue}; use std::fmt::Debug; -/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and -/// generically accumulates values. 
+/// Accumulates an aggregate's state.
+///
+/// `Accumulator`s are stateful objects that live throughout the
+/// evaluation of multiple rows and aggregate multiple values together
+/// into a final output aggregate.
 ///
 /// An accumulator knows how to:
 /// * update its state from inputs via `update_batch`
-/// * retract an update to its state from given inputs via `retract_batch`
+/// * (optionally) retract an update to its state from given inputs via `retract_batch`
 /// * convert its internal state to a vector of aggregate values
 /// * update its state from multiple accumulators' states via `merge_batch`
 /// * compute the final value from its internal state via `evaluate`
@@ -68,6 +71,16 @@ pub trait Accumulator: Send + Sync + Debug {
         ))
     }
 
+    /// Does the accumulator support incrementally updating its value
+    /// by *removing* values?
+    ///
+    /// If this function returns true, [`Self::retract_batch`] will be
+    /// called for sliding window functions such as queries with an
+    /// `OVER (ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING)` frame.
+    fn supports_retract_batch(&self) -> bool {
+        false
+    }
+
     /// Updates the accumulator's state from an `Array` containing one
     /// or more intermediate values.
     ///
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 6c3690e283d2..1b455a098539 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -24,13 +24,20 @@ use crate::{
 use std::fmt::{self, Debug, Formatter};
 use std::sync::Arc;
 
-/// Logical representation of a user-defined aggregate function (UDAF)
-/// A UDAF is different from a UDF in that it is stateful across batches.
+/// Logical representation of a user-defined aggregate function (UDAF).
+///
+/// A UDAF is different from a user-defined scalar function (UDF) in
+/// that it is stateful across batches.
UDAFs can be used as normal +/// aggregate functions as well as window functions (the `OVER` clause) +/// +/// For more information, please see [the examples] +/// +/// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process #[derive(Clone)] pub struct AggregateUDF { /// name pub name: String, - /// signature + /// Signature (input arguments) pub signature: Signature, /// Return type pub return_type: ReturnTypeFunction, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 3c14981355ec..fd9393afab0f 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -465,6 +465,7 @@ impl AsExecutionPlan for PhysicalPlanNode { AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &physical_schema, name) + .map(|func| func as Arc) } } }).transpose()?.ok_or_else(|| { From ce89853b7ac9c278004ca4940cd146f1ecda7fa3 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 15 Jun 2023 10:12:23 +0300 Subject: [PATCH 3/6] Replace supports_bounded_execution with supports_retract_batch --- .../physical-expr/src/aggregate/average.rs | 7 +++---- datafusion/physical-expr/src/aggregate/count.rs | 8 ++++---- .../physical-expr/src/aggregate/min_max.rs | 16 ++++++++-------- datafusion/physical-expr/src/aggregate/mod.rs | 6 ------ datafusion/physical-expr/src/aggregate/sum.rs | 8 ++++---- datafusion/physical-expr/src/window/aggregate.rs | 7 +++++-- .../src/window/sliding_aggregate.rs | 7 +++++-- 7 files changed, 29 insertions(+), 30 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 607572862290..3c76da51a9d4 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -134,10 +134,6 @@ impl AggregateExpr for Avg { is_row_accumulator_support_dtype(&self.sum_data_type) } - fn supports_bounded_execution(&self) -> bool { - true - } - fn create_row_accumulator( &self, start_index: usize, @@ -263,6 +259,9 @@ impl Accumulator for AvgAccumulator { )), } } + fn supports_retract_batch(&self) -> bool { + true + } fn size(&self) -> usize { std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size() diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 15df28b4e38a..22cb2512fc42 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -133,10 +133,6 @@ impl AggregateExpr for Count { true } - fn supports_bounded_execution(&self) -> bool { - true - } - fn create_row_accumulator( &self, start_index: usize, @@ -214,6 +210,10 @@ impl Accumulator for CountAccumulator { Ok(ScalarValue::Int64(Some(self.count))) } + fn supports_retract_batch(&self) -> bool { + true + } + fn size(&self) -> usize { std::mem::size_of_val(self) } diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index f811dae7b560..e3c061dc1354 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -125,10 +125,6 @@ impl AggregateExpr for Max { is_row_accumulator_support_dtype(&self.data_type) } - fn supports_bounded_execution(&self) -> bool { - true - } - fn create_row_accumulator( &self, start_index: usize, @@ 
-699,6 +695,10 @@ impl Accumulator for SlidingMaxAccumulator { Ok(self.max.clone()) } + fn supports_retract_batch(&self) -> bool { + true + } + fn size(&self) -> usize { std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() } @@ -825,10 +825,6 @@ impl AggregateExpr for Min { is_row_accumulator_support_dtype(&self.data_type) } - fn supports_bounded_execution(&self) -> bool { - true - } - fn create_row_accumulator( &self, start_index: usize, @@ -958,6 +954,10 @@ impl Accumulator for SlidingMinAccumulator { Ok(self.min.clone()) } + fn supports_retract_batch(&self) -> bool { + true + } + fn size(&self) -> usize { std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 09fd9bcfc524..7d2316c532a0 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -96,12 +96,6 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { false } - /// Specifies whether this aggregate function can run using bounded memory. - /// Any accumulator returning "true" needs to implement `retract_batch`. - fn supports_bounded_execution(&self) -> bool { - false - } - /// RowAccumulator to access/update row-based aggregation state in-place. /// Currently, row accumulator only supports states of fixed-sized type. /// diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 1c70dc67beeb..efa55f060264 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -131,10 +131,6 @@ impl AggregateExpr for Sum { is_row_accumulator_support_dtype(&self.data_type) } - fn supports_bounded_execution(&self) -> bool { - true - } - fn create_row_accumulator( &self, start_index: usize, @@ -361,6 +357,10 @@ impl Accumulator for SumAccumulator { } } + fn supports_retract_batch(&self) -> bool { + true + } + fn size(&self) -> usize { std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size() } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index c8a4797a5288..9a1d1f91f2ec 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -155,8 +155,11 @@ impl WindowExpr for PlainAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - self.aggregate.supports_bounded_execution() - && !self.window_frame.end_bound.is_unbounded() + if let Ok(acc) = self.aggregate.create_sliding_accumulator() { + acc.supports_retract_batch() && !self.window_frame.end_bound.is_unbounded() + } else { + false + } } } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 709f8d23be36..e849c00d5852 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -139,8 +139,11 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - self.aggregate.supports_bounded_execution() - && !self.window_frame.end_bound.is_unbounded() + if let Ok(acc) = self.aggregate.create_sliding_accumulator() { + acc.supports_retract_batch() && !self.window_frame.end_bound.is_unbounded() + } else { + false + } } } From fbc978e962f446617f9d2595adca63e6d103cb73 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 16 Jun 2023 
09:48:45 +0300 Subject: [PATCH 4/6] simplifications --- datafusion/core/src/physical_plan/udaf.rs | 7 +------ datafusion/proto/src/physical_plan/mod.rs | 1 - 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/datafusion/core/src/physical_plan/udaf.rs b/datafusion/core/src/physical_plan/udaf.rs index 9266c48ffd77..bca9eb878297 100644 --- a/datafusion/core/src/physical_plan/udaf.rs +++ b/datafusion/core/src/physical_plan/udaf.rs @@ -41,7 +41,7 @@ pub fn create_aggregate_expr( input_phy_exprs: &[Arc], input_schema: &Schema, name: impl Into, -) -> Result> { +) -> Result> { let input_exprs_types = input_phy_exprs .iter() .map(|arg| arg.data_type(input_schema)) @@ -70,11 +70,6 @@ impl AggregateFunctionExpr { pub fn fun(&self) -> &AggregateUDF { &self.fun } - - /// Returns true if this can support sliding accumulators - pub fn retractable(&self) -> Result { - Ok((self.fun.accumulator)(&self.data_type)?.supports_retract_batch()) - } } impl AggregateExpr for AggregateFunctionExpr { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index fd9393afab0f..3c14981355ec 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -465,7 +465,6 @@ impl AsExecutionPlan for PhysicalPlanNode { AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &physical_schema, name) - .map(|func| func as Arc) } } }).transpose()?.ok_or_else(|| { From 65d00e42df15154159fb301ca2285248db508bb2 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 16 Jun 2023 10:22:45 +0300 Subject: [PATCH 5/6] simplifications --- datafusion/physical-expr/src/window/aggregate.rs | 6 +----- datafusion/physical-expr/src/window/sliding_aggregate.rs | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 9a1d1f91f2ec..5892f7f3f3b0 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -155,11 +155,7 @@ impl WindowExpr for PlainAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - if let Ok(acc) = self.aggregate.create_sliding_accumulator() { - acc.supports_retract_batch() && !self.window_frame.end_bound.is_unbounded() - } else { - false - } + !self.window_frame.end_bound.is_unbounded() } } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index e849c00d5852..1494129cf897 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -139,11 +139,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - if let Ok(acc) = self.aggregate.create_sliding_accumulator() { - acc.supports_retract_batch() && !self.window_frame.end_bound.is_unbounded() - } else { - false - } + !self.window_frame.end_bound.is_unbounded() } } From 9a1bfa6a5adbb79363e9ab659c1b269382a3b298 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 16 Jun 2023 17:18:23 +0300 Subject: [PATCH 6/6] Rename evalaute_with_rank_all --- datafusion/physical-expr/src/window/built_in.rs | 2 +- datafusion/physical-expr/src/window/cume_dist.rs | 4 ++-- .../physical-expr/src/window/partition_evaluator.rs | 8 ++++---- datafusion/physical-expr/src/window/rank.rs | 6 +++--- 4 files changed, 10 
insertions(+), 10 deletions(-) diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index a03267c03532..828bc7218fa2 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -122,7 +122,7 @@ impl WindowExpr for BuiltInWindowExpr { } else if evaluator.include_rank() { let columns = self.sort_columns(batch)?; let sort_partition_points = evaluate_partition_ranges(num_rows, &columns)?; - evaluator.evaluate_with_rank_all(num_rows, &sort_partition_points) + evaluator.evaluate_all_with_rank(num_rows, &sort_partition_points) } else { let (values, _) = self.get_values_orderbys(batch)?; evaluator.evaluate_all(&values, num_rows) diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index 47f2e4208d71..9040165ac9e0 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -70,7 +70,7 @@ impl BuiltInWindowFunctionExpr for CumeDist { pub(crate) struct CumeDistEvaluator; impl PartitionEvaluator for CumeDistEvaluator { - fn evaluate_with_rank_all( + fn evaluate_all_with_rank( &self, num_rows: usize, ranks_in_partition: &[Range], @@ -109,7 +109,7 @@ mod tests { ) -> Result<()> { let result = expr .create_evaluator()? - .evaluate_with_rank_all(num_rows, &ranks)?; + .evaluate_all_with_rank(num_rows, &ranks)?; let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, *result); diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 0dfad0e80f05..e518e89a75d0 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -69,7 +69,7 @@ use std::ops::Range; /// /// # Stateless `PartitionEvaluator` /// -/// In this case, either [`Self::evaluate_all`] or [`Self::evaluate_with_rank_all`] is called with values for the +/// In this case, either [`Self::evaluate_all`] or [`Self::evaluate_all_with_rank`] is called with values for the /// entire partition. /// /// # Stateful `PartitionEvaluator` @@ -221,7 +221,7 @@ pub trait PartitionEvaluator: Debug + Send { )) } - /// [`PartitionEvaluator::evaluate_with_rank_all`] is called for window + /// [`PartitionEvaluator::evaluate_all_with_rank`] is called for window /// functions that only need the rank of a row within its window /// frame. 
/// @@ -248,7 +248,7 @@ pub trait PartitionEvaluator: Debug + Send { /// (3,4), /// ] /// ``` - fn evaluate_with_rank_all( + fn evaluate_all_with_rank( &self, _num_rows: usize, _ranks_in_partition: &[Range], @@ -278,7 +278,7 @@ pub trait PartitionEvaluator: Debug + Send { /// Can this function be evaluated with (only) rank /// - /// If `include_rank` is true, implement [`PartitionEvaluator::evaluate_with_rank_all`] + /// If `include_rank` is true, implement [`PartitionEvaluator::evaluate_all_with_rank`] fn include_rank(&self) -> bool { false } diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index be184ca891de..59a08358cda6 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -159,7 +159,7 @@ impl PartitionEvaluator for RankEvaluator { } } - fn evaluate_with_rank_all( + fn evaluate_all_with_rank( &self, num_rows: usize, ranks_in_partition: &[Range], @@ -236,7 +236,7 @@ mod tests { ) -> Result<()> { let result = expr .create_evaluator()? - .evaluate_with_rank_all(num_rows, &ranks)?; + .evaluate_all_with_rank(num_rows, &ranks)?; let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, *result); @@ -248,7 +248,7 @@ mod tests { ranks: Vec>, expected: Vec, ) -> Result<()> { - let result = expr.create_evaluator()?.evaluate_with_rank_all(8, &ranks)?; + let result = expr.create_evaluator()?.evaluate_all_with_rank(8, &ranks)?; let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(expected, *result);
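
To tie the series together, here is a minimal sketch (not part of these patches) of a user-defined accumulator that opts in to the new sliding-window path. It follows the `Accumulator` trait as it appears in this series (`state` and `evaluate` taking `&self`); the type name `SlidingSum`, the `Float64` sum, and the simplified null handling are invented for illustration, and registering it as a UDAF is omitted.

use arrow::array::{ArrayRef, AsArray};
use arrow::datatypes::Float64Type;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

/// Sums `Float64` values and can *retract* them again, so a bounded window
/// frame can slide without re-reading the whole frame for every output row.
#[derive(Debug, Default)]
struct SlidingSum {
    sum: f64,
}

impl Accumulator for SlidingSum {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::Float64(Some(self.sum))])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // Add the rows entering the window frame (nulls ignored for brevity).
        let arr = values[0].as_primitive::<Float64Type>();
        for v in arr.values().iter() {
            self.sum += *v;
        }
        Ok(())
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // Remove the rows leaving the window frame (inverse of `update_batch`).
        let arr = values[0].as_primitive::<Float64Type>();
        for v in arr.values().iter() {
            self.sum -= *v;
        }
        Ok(())
    }

    // Advertise retraction support so `create_window_expr` may pick
    // `SlidingAggregateWindowExpr` for bounded frames.
    fn supports_retract_batch(&self) -> bool {
        true
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // The intermediate state is just the partial sum, so merging is an update.
        self.update_batch(states)
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::Float64(Some(self.sum)))
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}

With such an accumulator registered, a query like `SELECT sliding_sum(x) OVER (ORDER BY t ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM v` can maintain one running sum per partition instead of re-aggregating every frame from scratch.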