From f527e13aff2e0c200e41429bcf3e4ae3ae653922 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Thu, 6 Jun 2024 20:39:35 +0800 Subject: [PATCH 01/11] expr builder Signed-off-by: jayzhan211 --- datafusion-examples/examples/udaf_expr.rs | 44 ++++++++++++++++ datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udaf.rs | 51 ++++++++++++++++++- .../functions-aggregate/src/first_last.rs | 19 ++++--- datafusion/functions-aggregate/src/macros.rs | 32 +++--------- .../src/replace_distinct_aggregate.rs | 18 +++---- .../tests/cases/roundtrip_logical_plan.rs | 3 +- docs/source/user-guide/expressions.md | 10 ++++ 8 files changed, 130 insertions(+), 49 deletions(-) create mode 100644 datafusion-examples/examples/udaf_expr.rs diff --git a/datafusion-examples/examples/udaf_expr.rs b/datafusion-examples/examples/udaf_expr.rs new file mode 100644 index 000000000000..d655ad44160a --- /dev/null +++ b/datafusion-examples/examples/udaf_expr.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{ + execution::{ + config::SessionConfig, + context::{SessionContext, SessionState}, + }, + functions_aggregate::{expr_fn::first_value, register_all}, +}; + +use datafusion_common::Result; +use datafusion_expr::{col, AggregateUDFExprBuilder}; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + let config = SessionConfig::new(); + let mut state = SessionState::new_with_config_rt(config, ctx.runtime_env()); + let _ = register_all(&mut state); + + let first_value_udaf = state.aggregate_functions().get("first_value").unwrap(); + let first_value_builder = first_value_udaf + .call(vec![col("a")]) + .order_by(vec![col("b")]); + + let first_value_fn = first_value(col("a"), Some(vec![col("b")])); + assert_eq!(first_value_builder, first_value_fn); + Ok(()) +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 8c9893b8a748..8edd982261cd 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -81,7 +81,7 @@ pub use signature::{ ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; +pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, AggregateUDFExprBuilder}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index d778203207c9..b5944df5805f 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,6 +17,7 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions +use crate::expr::AggregateFunction; use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, }; @@ -27,6 +28,7 @@ use crate::{Accumulator, Expr}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, not_impl_err, Result}; +use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -139,8 +141,7 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { - // TODO: Support dictinct, filter, order by and null_treatment - Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(AggregateFunction::new_udf( Arc::new(self.clone()), args, false, @@ -606,3 +607,49 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { (self.accumulator)(acc_args) } } + +pub trait AggregateUDFExprBuilder { + fn order_by(self, order_by: Vec) -> Expr; + fn filter(self, filter: Box) -> Expr; + fn null_treatment(self, null_treatment: NullTreatment) -> Expr; + fn distinct(self) -> Expr; +} + +impl AggregateUDFExprBuilder for Expr { + fn order_by(self, order_by: Vec) -> Expr { + match self { + Expr::AggregateFunction(mut udaf) => { + udaf.order_by = Some(order_by); + Expr::AggregateFunction(udaf) + } + _ => self, + } + } + fn filter(self, filter: Box) -> Expr { + match self { + Expr::AggregateFunction(mut udaf) => { + udaf.filter = Some(filter); + Expr::AggregateFunction(udaf) + } + _ => self, + } + } + fn null_treatment(self, null_treatment: NullTreatment) -> Expr { + match self { + Expr::AggregateFunction(mut udaf) => { + udaf.null_treatment = Some(null_treatment); + Expr::AggregateFunction(udaf) + } + _ => self, + } + } + fn distinct(self) -> Expr { + match self { + Expr::AggregateFunction(mut udaf) => { + udaf.distinct = true; + Expr::AggregateFunction(udaf) + } + _ => self, + } + } +} \ No newline at end of file diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 435d277473c4..fd4713767870 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,20 +31,23 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Signature, TypeSignature, - Volatility, + Accumulator, AggregateUDFExprBuilder, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, TypeSignature, Volatility }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{ limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr, }; -make_udaf_expr_and_func!( - FirstValue, - first_value, - "Returns the first value in a group of values.", - first_value_udaf -); +create_func!(FirstValue, first_value_udaf); + +/// Returns the first value in a group of values. +pub fn first_value(expression: Expr, order_by: Option>) -> Expr { + if let Some(order_by) = order_by { + first_value_udaf().call(vec![expression]).order_by(order_by) + } else { + first_value_udaf().call(vec![expression]) + } +} pub struct FirstValue { signature: Signature, diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 6c3348d6c1d6..75bb9dc54719 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -48,24 +48,7 @@ macro_rules! make_udaf_expr_and_func { None, )) } - create_func!($UDAF, $AGGREGATE_UDF_FN); - }; - ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $distinct:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { - // "fluent expr_fn" style function - #[doc = $DOC] - pub fn $EXPR_FN( - $($arg: datafusion_expr::Expr,)* - distinct: bool, - ) -> datafusion_expr::Expr { - datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( - $AGGREGATE_UDF_FN(), - vec![$($arg),*], - distinct, - None, - None, - None - )) - } + create_func!($UDAF, $AGGREGATE_UDF_FN); }; ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { @@ -73,20 +56,17 @@ macro_rules! make_udaf_expr_and_func { #[doc = $DOC] pub fn $EXPR_FN( args: Vec, - distinct: bool, - filter: Option>, - order_by: Option>, - null_treatment: Option ) -> datafusion_expr::Expr { datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), args, - distinct, - filter, - order_by, - null_treatment, + false, + None, + None, + None, )) } + create_func!($UDAF, $AGGREGATE_UDF_FN); }; } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 752e2b200741..d0bf51b845ea 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -21,10 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{internal_err, Column, Result}; -use datafusion_expr::expr::AggregateFunction; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{col, LogicalPlanBuilder}; +use datafusion_expr::{col, AggregateUDFExprBuilder, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] @@ -95,17 +94,14 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { let expr_cnt = on_expr.len(); // Construct the aggregation expression to be used to fetch the selected expressions. - let first_value_udaf = + let first_value_udaf: std::sync::Arc = config.function_registry().unwrap().udaf("first_value")?; let aggr_expr = select_expr.into_iter().map(|e| { - Expr::AggregateFunction(AggregateFunction::new_udf( - first_value_udaf.clone(), - vec![e], - false, - None, - sort_expr.clone(), - None, - )) + if let Some(order_by) = &sort_expr { + first_value_udaf.call(vec![e]).order_by(order_by.clone()) + } else { + first_value_udaf.call(vec![e]) + } }); let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index deae97fecc96..d4f313c67514 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -647,7 +647,8 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), - first_value(vec![lit(1)], false, None, None, None), + first_value(lit(1), None), + first_value(lit(1), Some(vec![lit(2)])), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), sum(lit(1)), diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index a5fc13491677..2617c986dba9 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -304,6 +304,16 @@ select log(-1), log(0), sqrt(-1); | rollup(exprs) | Creates a grouping set for rollup sets. | | sum(expr) | Сalculates the sum of `expr`. | +## Aggregate Function Builder + +Import trait `AggregateUDFExprBuilder` and update the arguments directly in `Expr` + +See datafusion-examples/examples/udaf_expr.rs for example usage. + +| Syntax | Equivalent to | +| ------------------------------------------------------ | ----------------------------------- | +| first_value_udaf.call(vec![expr]).order_by(vec![expr]) | first_value(expr, Some(vec![expr])) | + ## Subquery Expressions | Syntax | Description | From 698ab8ffa47c4e6dbac818706bf6e5a1c8b8ec51 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Thu, 6 Jun 2024 20:39:52 +0800 Subject: [PATCH 02/11] fmt Signed-off-by: jayzhan211 --- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udaf.rs | 2 +- datafusion/functions-aggregate/src/first_last.rs | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 8edd982261cd..aed7f3321df6 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -81,7 +81,7 @@ pub use signature::{ ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, AggregateUDFExprBuilder}; +pub use udaf::{AggregateUDF, AggregateUDFExprBuilder, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b5944df5805f..f1dd74ef733d 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -652,4 +652,4 @@ impl AggregateUDFExprBuilder for Expr { _ => self, } } -} \ No newline at end of file +} diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index fd4713767870..2758ffdc2adb 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,7 +31,8 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFExprBuilder, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, TypeSignature, Volatility + Accumulator, AggregateUDFExprBuilder, AggregateUDFImpl, ArrayFunctionSignature, Expr, + Signature, TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{ From 6994621af562cea91a78af0fc6f276439535370e Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 7 Jun 2024 20:24:08 +0800 Subject: [PATCH 03/11] build Signed-off-by: jayzhan211 --- datafusion-examples/examples/udaf_expr.rs | 5 +- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udaf.rs | 94 ++++++++++++++----- .../functions-aggregate/src/first_last.rs | 10 +- .../src/replace_distinct_aggregate.rs | 8 +- 5 files changed, 90 insertions(+), 29 deletions(-) diff --git a/datafusion-examples/examples/udaf_expr.rs b/datafusion-examples/examples/udaf_expr.rs index d655ad44160a..d90bfba5b61c 100644 --- a/datafusion-examples/examples/udaf_expr.rs +++ b/datafusion-examples/examples/udaf_expr.rs @@ -24,7 +24,7 @@ use datafusion::{ }; use datafusion_common::Result; -use datafusion_expr::{col, AggregateUDFExprBuilder}; +use datafusion_expr::{col, AggregateExt}; #[tokio::main] async fn main() -> Result<()> { @@ -36,7 +36,8 @@ async fn main() -> Result<()> { let first_value_udaf = state.aggregate_functions().get("first_value").unwrap(); let first_value_builder = first_value_udaf .call(vec![col("a")]) - .order_by(vec![col("b")]); + .order_by(vec![col("b")]) + .build()?; let first_value_fn = first_value(col("a"), Some(vec![col("b")])); assert_eq!(first_value_builder, first_value_fn); diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index aed7f3321df6..c2d40a7fe4f1 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -81,7 +81,7 @@ pub use signature::{ ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateUDF, AggregateUDFExprBuilder, AggregateUDFImpl, ReversedUDAF}; +pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index f1dd74ef733d..4a5ee8554b8c 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -27,7 +27,7 @@ use crate::utils::AggregateOrderSensitivity; use crate::{Accumulator, Expr}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::{self, Debug, Formatter}; @@ -608,48 +608,100 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { } } -pub trait AggregateUDFExprBuilder { - fn order_by(self, order_by: Vec) -> Expr; - fn filter(self, filter: Box) -> Expr; - fn null_treatment(self, null_treatment: NullTreatment) -> Expr; - fn distinct(self) -> Expr; +pub trait AggregateExt { + fn order_by(self, order_by: Vec) -> AggregateBuilder; + fn filter(self, filter: Box) -> AggregateBuilder; + fn distinct(self) -> AggregateBuilder; + fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder; } -impl AggregateUDFExprBuilder for Expr { - fn order_by(self, order_by: Vec) -> Expr { +pub struct AggregateBuilder { + udaf: Option, + order_by: Option>, + filter: Option>, + distinct: bool, + null_treatment: Option, +} + +impl AggregateBuilder { + fn new(udaf: Option) -> Self { + Self { + udaf, + order_by: None, + filter: None, + distinct: false, + null_treatment: None, + } + } + + pub fn build(self) -> Result { + if let Some(mut udaf) = self.udaf { + udaf.order_by = self.order_by; + udaf.filter = self.filter; + udaf.distinct = self.distinct; + udaf.null_treatment = self.null_treatment; + return Ok(Expr::AggregateFunction(udaf)); + } + + plan_err!("Expect Expr::AggregateFunction") + } + + pub fn order_by(mut self, order_by: Vec) -> AggregateBuilder { + self.order_by = Some(order_by); + self + } + + pub fn filter(mut self, filter: Box) -> AggregateBuilder { + self.filter = Some(filter); + self + } + + pub fn distinct(mut self) -> AggregateBuilder { + self.distinct = true; + self + } + + pub fn null_treatment(mut self, null_treatment: NullTreatment) -> AggregateBuilder { + self.null_treatment = Some(null_treatment); + self + } +} + +impl AggregateExt for Expr { + fn order_by(self, order_by: Vec) -> AggregateBuilder { match self { Expr::AggregateFunction(mut udaf) => { udaf.order_by = Some(order_by); - Expr::AggregateFunction(udaf) + AggregateBuilder::new(Some(udaf)) } - _ => self, + _ => AggregateBuilder::new(None), } } - fn filter(self, filter: Box) -> Expr { + fn filter(self, filter: Box) -> AggregateBuilder { match self { Expr::AggregateFunction(mut udaf) => { udaf.filter = Some(filter); - Expr::AggregateFunction(udaf) + AggregateBuilder::new(Some(udaf)) } - _ => self, + _ => AggregateBuilder::new(None), } } - fn null_treatment(self, null_treatment: NullTreatment) -> Expr { + fn distinct(self) -> AggregateBuilder { match self { Expr::AggregateFunction(mut udaf) => { - udaf.null_treatment = Some(null_treatment); - Expr::AggregateFunction(udaf) + udaf.distinct = true; + AggregateBuilder::new(Some(udaf)) } - _ => self, + _ => AggregateBuilder::new(None), } } - fn distinct(self) -> Expr { + fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder { match self { Expr::AggregateFunction(mut udaf) => { - udaf.distinct = true; - Expr::AggregateFunction(udaf) + udaf.null_treatment = Some(null_treatment); + AggregateBuilder::new(Some(udaf)) } - _ => self, + _ => AggregateBuilder::new(None), } } } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 2758ffdc2adb..c7ed6c94fea3 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,8 +31,8 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFExprBuilder, AggregateUDFImpl, ArrayFunctionSignature, Expr, - Signature, TypeSignature, Volatility, + Accumulator, AggregateExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, + TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{ @@ -44,7 +44,11 @@ create_func!(FirstValue, first_value_udaf); /// Returns the first value in a group of values. pub fn first_value(expression: Expr, order_by: Option>) -> Expr { if let Some(order_by) = order_by { - first_value_udaf().call(vec![expression]).order_by(order_by) + first_value_udaf() + .call(vec![expression]) + .order_by(order_by) + .build() + .unwrap() } else { first_value_udaf().call(vec![expression]) } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index d0bf51b845ea..bc55138c6799 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -23,7 +23,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{internal_err, Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{col, AggregateUDFExprBuilder, LogicalPlanBuilder}; +use datafusion_expr::{col, AggregateExt, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] @@ -98,7 +98,11 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { config.function_registry().unwrap().udaf("first_value")?; let aggr_expr = select_expr.into_iter().map(|e| { if let Some(order_by) = &sort_expr { - first_value_udaf.call(vec![e]).order_by(order_by.clone()) + first_value_udaf + .call(vec![e]) + .order_by(order_by.clone()) + .build() + .unwrap() } else { first_value_udaf.call(vec![e]) } From 3c10e71300478fd432b75ddf821568be0dde2ace Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 7 Jun 2024 20:27:09 +0800 Subject: [PATCH 04/11] upd user-guide Signed-off-by: jayzhan211 --- docs/source/user-guide/expressions.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 2617c986dba9..081960fee76b 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -310,9 +310,9 @@ Import trait `AggregateUDFExprBuilder` and update the arguments directly in `Exp See datafusion-examples/examples/udaf_expr.rs for example usage. -| Syntax | Equivalent to | -| ------------------------------------------------------ | ----------------------------------- | -| first_value_udaf.call(vec![expr]).order_by(vec![expr]) | first_value(expr, Some(vec![expr])) | +| Syntax | Equivalent to | +| ----------------------------------------------------------------------- | ----------------------------------- | +| first_value_udaf.call(vec![expr]).order_by(vec![expr]).build().unwrap() | first_value(expr, Some(vec![expr])) | ## Subquery Expressions From d2fc1d1d1dfc70e76ece1bb19088f863b6d634fd Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 7 Jun 2024 21:34:20 +0800 Subject: [PATCH 05/11] fix builder Signed-off-by: jayzhan211 --- datafusion-examples/examples/udaf_expr.rs | 11 +++++++- datafusion/expr/src/udaf.rs | 28 +++++++++++-------- .../functions-aggregate/src/first_last.rs | 1 + .../src/replace_distinct_aggregate.rs | 1 + 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/datafusion-examples/examples/udaf_expr.rs b/datafusion-examples/examples/udaf_expr.rs index d90bfba5b61c..68dbbd0ccc2e 100644 --- a/datafusion-examples/examples/udaf_expr.rs +++ b/datafusion-examples/examples/udaf_expr.rs @@ -24,7 +24,7 @@ use datafusion::{ }; use datafusion_common::Result; -use datafusion_expr::{col, AggregateExt}; +use datafusion_expr::{col, expr::AggregateFunction, AggregateExt, Expr}; #[tokio::main] async fn main() -> Result<()> { @@ -40,6 +40,15 @@ async fn main() -> Result<()> { .build()?; let first_value_fn = first_value(col("a"), Some(vec![col("b")])); + let first_value_manual = Expr::AggregateFunction(AggregateFunction::new_udf( + first_value_udaf.clone(), + vec![col("a")], + false, + None, + Some(vec![col("b")]), + None, + )); assert_eq!(first_value_builder, first_value_fn); + assert_eq!(first_value_builder, first_value_manual); Ok(()) } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 4a5ee8554b8c..3fcfafb9d145 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -670,36 +670,40 @@ impl AggregateBuilder { impl AggregateExt for Expr { fn order_by(self, order_by: Vec) -> AggregateBuilder { match self { - Expr::AggregateFunction(mut udaf) => { - udaf.order_by = Some(order_by); - AggregateBuilder::new(Some(udaf)) + Expr::AggregateFunction(udaf) => { + let mut builder = AggregateBuilder::new(Some(udaf)); + builder.order_by = Some(order_by); + builder } _ => AggregateBuilder::new(None), } } fn filter(self, filter: Box) -> AggregateBuilder { match self { - Expr::AggregateFunction(mut udaf) => { - udaf.filter = Some(filter); - AggregateBuilder::new(Some(udaf)) + Expr::AggregateFunction(udaf) => { + let mut builder = AggregateBuilder::new(Some(udaf)); + builder.filter = Some(filter); + builder } _ => AggregateBuilder::new(None), } } fn distinct(self) -> AggregateBuilder { match self { - Expr::AggregateFunction(mut udaf) => { - udaf.distinct = true; - AggregateBuilder::new(Some(udaf)) + Expr::AggregateFunction(udaf) => { + let mut builder = AggregateBuilder::new(Some(udaf)); + builder.distinct = true; + builder } _ => AggregateBuilder::new(None), } } fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder { match self { - Expr::AggregateFunction(mut udaf) => { - udaf.null_treatment = Some(null_treatment); - AggregateBuilder::new(Some(udaf)) + Expr::AggregateFunction(udaf) => { + let mut builder = AggregateBuilder::new(Some(udaf)); + builder.null_treatment = Some(null_treatment); + builder } _ => AggregateBuilder::new(None), } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index c7ed6c94fea3..dd38e3487264 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -48,6 +48,7 @@ pub fn first_value(expression: Expr, order_by: Option>) -> Expr { .call(vec![expression]) .order_by(order_by) .build() + // guaranteed to be `Expr::AggregateFunction` .unwrap() } else { first_value_udaf().call(vec![expression]) diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index bc55138c6799..b32a88635395 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -102,6 +102,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { .call(vec![e]) .order_by(order_by.clone()) .build() + // guaranteed to be `Expr::AggregateFunction` .unwrap() } else { first_value_udaf.call(vec![e]) From f2fc8d5fcf07af9275f903cdd074b0df15927331 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 8 Jun 2024 09:46:49 -0400 Subject: [PATCH 06/11] Consolidate example in udaf_expr.rs, simplify filter API --- datafusion-examples/examples/expr_api.rs | 38 ++++++++++++++-- datafusion-examples/examples/udaf_expr.rs | 54 ----------------------- datafusion/expr/src/udaf.rs | 10 ++--- docs/source/user-guide/expressions.md | 4 +- 4 files changed, 42 insertions(+), 64 deletions(-) delete mode 100644 datafusion-examples/examples/udaf_expr.rs diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 0082ed6eb9a9..d6164458e7bc 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -24,6 +24,7 @@ use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::DFSchema; use datafusion::error::Result; +use datafusion::functions_aggregate::first_last::first_value_udaf; use datafusion::optimizer::simplify_expressions::ExprSimplifier; use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries}; use datafusion::prelude::*; @@ -32,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; +use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. /// @@ -44,8 +45,9 @@ use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; /// also comes with APIs for evaluation, simplification, and analysis. /// /// The code in this example shows how to: -/// 1. Create [`Exprs`] using different APIs: [`main`]` -/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`] +/// 1. Create [`Expr`]s using different APIs: [`main`]` +/// 2. Use the fluent API to easly create complex [`Expr`]s: [`expr_fn_demo`] +/// 2. Evaluate [`Expr`]s against data: [`evaluate_demo`] /// 3. Simplify expressions: [`simplify_demo`] /// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`] /// 5. Get the types of the expressions: [`expression_type_demo`] @@ -63,6 +65,9 @@ async fn main() -> Result<()> { )); assert_eq!(expr, expr2); + // See how to build aggregate functions with a fluent API + expr_fn_demo()?; + // See how to evaluate expressions evaluate_demo()?; @@ -78,6 +83,33 @@ async fn main() -> Result<()> { Ok(()) } +/// Datafusion's `expr_fn` API makes it easy to create [`Expr`]s for the +/// full range of expression types such as aggregates and window functions. +fn expr_fn_demo() -> Result<()> { + // Let's say you want to call the "first_value" aggregate function + let first_value = first_value_udaf(); + + // For example, to create the expression `FIRST_VALUE(price)` + // These expressions can be passed to `DataFrame::aggregate` and other + // APIs that take aggregate expressions. + let agg = first_value.call(vec![col("price")]); + assert_eq!(agg.to_string(), "first_value(price)"); + + // You can use the AggregateExt trait to create more complex aggregates + // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts ) + let agg = first_value + .call(vec![col("price")]) + .order_by(vec![col("ts")]) + .filter(col("quantity").gt(lit(100))) + .build()?; // build the aggregate + assert_eq!( + agg.to_string(), + "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts]" + ); + + Ok(()) +} + /// DataFusion can also evaluate arbitrary expressions on Arrow arrays. fn evaluate_demo() -> Result<()> { // For example, let's say you have some integers in an array diff --git a/datafusion-examples/examples/udaf_expr.rs b/datafusion-examples/examples/udaf_expr.rs deleted file mode 100644 index 68dbbd0ccc2e..000000000000 --- a/datafusion-examples/examples/udaf_expr.rs +++ /dev/null @@ -1,54 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::{ - execution::{ - config::SessionConfig, - context::{SessionContext, SessionState}, - }, - functions_aggregate::{expr_fn::first_value, register_all}, -}; - -use datafusion_common::Result; -use datafusion_expr::{col, expr::AggregateFunction, AggregateExt, Expr}; - -#[tokio::main] -async fn main() -> Result<()> { - let ctx = SessionContext::new(); - let config = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(config, ctx.runtime_env()); - let _ = register_all(&mut state); - - let first_value_udaf = state.aggregate_functions().get("first_value").unwrap(); - let first_value_builder = first_value_udaf - .call(vec![col("a")]) - .order_by(vec![col("b")]) - .build()?; - - let first_value_fn = first_value(col("a"), Some(vec![col("b")])); - let first_value_manual = Expr::AggregateFunction(AggregateFunction::new_udf( - first_value_udaf.clone(), - vec![col("a")], - false, - None, - Some(vec![col("b")]), - None, - )); - assert_eq!(first_value_builder, first_value_fn); - assert_eq!(first_value_builder, first_value_manual); - Ok(()) -} diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 3fcfafb9d145..f40d8cb159ba 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -610,7 +610,7 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { pub trait AggregateExt { fn order_by(self, order_by: Vec) -> AggregateBuilder; - fn filter(self, filter: Box) -> AggregateBuilder; + fn filter(self, filter: Expr) -> AggregateBuilder; fn distinct(self) -> AggregateBuilder; fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder; } @@ -618,7 +618,7 @@ pub trait AggregateExt { pub struct AggregateBuilder { udaf: Option, order_by: Option>, - filter: Option>, + filter: Option, distinct: bool, null_treatment: Option, } @@ -637,7 +637,7 @@ impl AggregateBuilder { pub fn build(self) -> Result { if let Some(mut udaf) = self.udaf { udaf.order_by = self.order_by; - udaf.filter = self.filter; + udaf.filter = self.filter.map(Box::new); udaf.distinct = self.distinct; udaf.null_treatment = self.null_treatment; return Ok(Expr::AggregateFunction(udaf)); @@ -651,7 +651,7 @@ impl AggregateBuilder { self } - pub fn filter(mut self, filter: Box) -> AggregateBuilder { + pub fn filter(mut self, filter: Expr) -> AggregateBuilder { self.filter = Some(filter); self } @@ -678,7 +678,7 @@ impl AggregateExt for Expr { _ => AggregateBuilder::new(None), } } - fn filter(self, filter: Box) -> AggregateBuilder { + fn filter(self, filter: Expr) -> AggregateBuilder { match self { Expr::AggregateFunction(udaf) => { let mut builder = AggregateBuilder::new(Some(udaf)); diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 081960fee76b..cae9627210e5 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -306,9 +306,9 @@ select log(-1), log(0), sqrt(-1); ## Aggregate Function Builder -Import trait `AggregateUDFExprBuilder` and update the arguments directly in `Expr` +You can also use the `AggregateExt` trait to more easily build Aggregate arguments `Expr`. -See datafusion-examples/examples/udaf_expr.rs for example usage. +See `datafusion-examples/examples/expr_api.rs` for example usage. | Syntax | Equivalent to | | ----------------------------------------------------------------------- | ----------------------------------- | From a471385f921c43616695d4a972528ca6cdaf1c90 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 8 Jun 2024 10:08:09 -0400 Subject: [PATCH 07/11] Add doc strings and examples --- datafusion/expr/src/udaf.rs | 64 ++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index f40d8cb159ba..74ee4ba5b3d8 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -608,13 +608,47 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { } } +/// Extensions for configuring [`Expr::AggregateFunction`] +/// +/// Adds methods to [`Expr`] that make it easy to set optional aggregate options +/// such as `ORDER BY`, `FILTER` and `DISTINCT` +/// +/// # Example +/// ```no_run +/// # use datafusion_common::Result; +/// # use datafusion_expr::{AggregateUDF, col, Expr, lit}; +/// # use sqlparser::ast::NullTreatment; +/// # fn count(arg: Expr) -> Expr { todo!{} } +/// # fn first_value(arg: Expr) -> Expr { todo!{} } +/// # fn main() -> Result<()> { +/// use datafusion_expr::AggregateExt; +/// +/// // Create COUNT(x FILTER y > 5) +/// let agg = count(col("x")) +/// .filter(col("y").gt(lit(5))) +/// .build()?; +/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS) +/// let agg = first_value(col("x")) +/// .order_by(vec![col("y")]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// # Ok(()) +/// # } +/// ``` pub trait AggregateExt { + /// Add `ORDER BY ` fn order_by(self, order_by: Vec) -> AggregateBuilder; + /// Add `FILTER ` fn filter(self, filter: Expr) -> AggregateBuilder; + /// Add `DISTINCT` fn distinct(self) -> AggregateBuilder; + /// Add `RESPECT NULLS` or `IGNORE NULLS` fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder; } +/// Implementation of [`AggregateExt`]. +/// +/// See [`AggregateExec`] for usage and examples pub struct AggregateBuilder { udaf: Option, order_by: Option>, @@ -624,6 +658,8 @@ pub struct AggregateBuilder { } impl AggregateBuilder { + /// Create a new `AggregateBuilder`, see [`AggregateExt`] + fn new(udaf: Option) -> Self { Self { udaf, @@ -634,33 +670,45 @@ impl AggregateBuilder { } } + /// Updates and returns the in progress [`Expr::AggregateFunction`] + /// + /// # Errors: + /// + /// Returns an error of this builder [`AggregateExt`] was used with an + /// `Expr` variant other than [`Expr::AggregateFunction`] pub fn build(self) -> Result { - if let Some(mut udaf) = self.udaf { - udaf.order_by = self.order_by; - udaf.filter = self.filter.map(Box::new); - udaf.distinct = self.distinct; - udaf.null_treatment = self.null_treatment; - return Ok(Expr::AggregateFunction(udaf)); - } + let Some(mut udaf) = self.udaf else { + return plan_err!( + "AggregateExt can only be used with Expr::AggregateFunction" + ); + }; - plan_err!("Expect Expr::AggregateFunction") + udaf.order_by = self.order_by; + udaf.filter = self.filter.map(Box::new); + udaf.distinct = self.distinct; + udaf.null_treatment = self.null_treatment; + Ok(Expr::AggregateFunction(udaf)) } + /// Add `ORDER BY ` pub fn order_by(mut self, order_by: Vec) -> AggregateBuilder { self.order_by = Some(order_by); self } + /// Add `FILTER ` pub fn filter(mut self, filter: Expr) -> AggregateBuilder { self.filter = Some(filter); self } + /// Add `DISTINCT` pub fn distinct(mut self) -> AggregateBuilder { self.distinct = true; self } + /// Add `RESPECT NULLS` or `IGNORE NULLS` pub fn null_treatment(mut self, null_treatment: NullTreatment) -> AggregateBuilder { self.null_treatment = Some(null_treatment); self From fc5a6a5e7efc704527748f0be58ce6a4854ca9d7 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 8 Jun 2024 10:47:50 -0400 Subject: [PATCH 08/11] Add tests and checks --- datafusion/core/tests/expr_api/mod.rs | 190 +++++++++++++++++++++++++- datafusion/expr/src/udaf.rs | 36 ++++- 2 files changed, 217 insertions(+), 9 deletions(-) diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 1db5aa9f235a..7085333bee03 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -15,14 +15,18 @@ // specific language governing permissions and limitations // under the License. -use arrow::util::pretty::pretty_format_columns; +use arrow::util::pretty::{pretty_format_batches, pretty_format_columns}; use arrow_array::builder::{ListBuilder, StringBuilder}; -use arrow_array::{ArrayRef, RecordBatch, StringArray, StructArray}; +use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Field}; use datafusion::prelude::*; -use datafusion_common::{DFSchema, ScalarValue}; +use datafusion_common::{assert_contains, DFSchema, ScalarValue}; +use datafusion_expr::AggregateExt; use datafusion_functions::core::expr_ext::FieldAccessor; +use datafusion_functions_aggregate::first_last::first_value_udaf; +use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_functions_array::expr_ext::{IndexAccessor, SliceAccessor}; +use sqlparser::ast::NullTreatment; /// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan use std::sync::{Arc, OnceLock}; @@ -162,6 +166,183 @@ fn test_list_range() { ); } +#[tokio::test] +async fn test_aggregate_error() { + let err = first_value_udaf() + .call(vec![col("props")]) + // not a sort column + .order_by(vec![col("id")]) + .build() + .unwrap_err() + .to_string(); + assert_contains!( + err, + "Error during planning: ORDER BY expressions must be Expr::Sort" + ); +} + +#[tokio::test] +async fn test_aggregate_ext_order_by() { + let agg = first_value_udaf().call(vec![col("props")]); + + // ORDER BY id ASC + let agg_asc = agg + .clone() + .order_by(vec![col("id").sort(true, true)]) + .build() + .unwrap() + .alias("asc"); + + // ORDER BY id DESC + let agg_desc = agg + .order_by(vec![col("id").sort(false, true)]) + .build() + .unwrap() + .alias("desc"); + + evaluate_agg_test( + agg_asc, + vec![ + "+-----------------+", + "| asc |", + "+-----------------+", + "| {a: 2021-02-01} |", + "+-----------------+", + ], + ) + .await; + + evaluate_agg_test( + agg_desc, + vec![ + "+-----------------+", + "| desc |", + "+-----------------+", + "| {a: 2021-02-03} |", + "+-----------------+", + ], + ) + .await; +} + +#[tokio::test] +async fn test_aggregate_ext_filter() { + let agg = first_value_udaf() + .call(vec![col("i")]) + .order_by(vec![col("i").sort(true, true)]) + .filter(col("i").is_not_null()) + .build() + .unwrap() + .alias("val"); + + #[rustfmt::skip] + evaluate_agg_test( + agg, + vec![ + "+-----+", + "| val |", + "+-----+", + "| 5 |", + "+-----+", + ], + ) + .await; +} + +#[tokio::test] +async fn test_aggregate_ext_distinct() { + let agg = sum_udaf() + .call(vec![lit(5)]) + // distinct sum should be 5, not 15 + .distinct() + .build() + .unwrap() + .alias("distinct"); + + evaluate_agg_test( + agg, + vec![ + "+----------+", + "| distinct |", + "+----------+", + "| 5 |", + "+----------+", + ], + ) + .await; +} + +#[tokio::test] +async fn test_aggregate_ext_null_treatment() { + let agg = first_value_udaf() + .call(vec![col("i")]) + .order_by(vec![col("i").sort(true, true)]); + + let agg_respect = agg + .clone() + .null_treatment(NullTreatment::RespectNulls) + .build() + .unwrap() + .alias("respect"); + + let agg_ignore = agg + .null_treatment(NullTreatment::IgnoreNulls) + .build() + .unwrap() + .alias("ignore"); + + evaluate_agg_test( + agg_respect, + vec![ + "+---------+", + "| respect |", + "+---------+", + "| |", + "+---------+", + ], + ) + .await; + + evaluate_agg_test( + agg_ignore, + vec![ + "+--------+", + "| ignore |", + "+--------+", + "| 5 |", + "+--------+", + ], + ) + .await; +} + +/// Evaluates the specified expr as an aggregate and compares the result to the +/// expected result. +async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { + let batch = test_batch(); + + let ctx = SessionContext::new(); + let group_expr = vec![]; + let agg_expr = vec![expr]; + let result = ctx + .read_batch(batch) + .unwrap() + .aggregate(group_expr, agg_expr) + .unwrap() + .collect() + .await + .unwrap(); + + let result = pretty_format_batches(&result).unwrap().to_string(); + let actual_lines = result.lines().collect::>(); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); +} + /// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided /// `RecordBatch` and compares the result to the expected result. fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { @@ -189,6 +370,8 @@ fn test_batch() -> RecordBatch { TEST_BATCH .get_or_init(|| { let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3"])); + let int_array: ArrayRef = + Arc::new(Int64Array::from_iter(vec![Some(10), None, Some(5)])); // { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" } let struct_array: ArrayRef = Arc::from(StructArray::from(vec![( @@ -209,6 +392,7 @@ fn test_batch() -> RecordBatch { RecordBatch::try_from_iter(vec![ ("id", string_array), + ("i", int_array), ("props", struct_array), ("list", list_array), ]) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 74ee4ba5b3d8..3d4b6780785d 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -628,8 +628,9 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { /// .filter(col("y").gt(lit(5))) /// .build()?; /// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS) +/// let sort_expr = col("y").sort(true, true); /// let agg = first_value(col("x")) -/// .order_by(vec![col("y")]) +/// .order_by(vec![sort_expr]) /// .null_treatment(NullTreatment::IgnoreNulls) /// .build()?; /// # Ok(()) @@ -637,6 +638,8 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { /// ``` pub trait AggregateExt { /// Add `ORDER BY ` + /// + /// Note: `order_by` must be [`Expr::Sort`] fn order_by(self, order_by: Vec) -> AggregateBuilder; /// Add `FILTER ` fn filter(self, filter: Expr) -> AggregateBuilder; @@ -649,6 +652,7 @@ pub trait AggregateExt { /// Implementation of [`AggregateExt`]. /// /// See [`AggregateExec`] for usage and examples +#[derive(Debug, Clone)] pub struct AggregateBuilder { udaf: Option, order_by: Option>, @@ -677,20 +681,40 @@ impl AggregateBuilder { /// Returns an error of this builder [`AggregateExt`] was used with an /// `Expr` variant other than [`Expr::AggregateFunction`] pub fn build(self) -> Result { - let Some(mut udaf) = self.udaf else { + let Self { + udaf, + order_by, + filter, + distinct, + null_treatment, + } = self; + + let Some(mut udaf) = udaf else { return plan_err!( "AggregateExt can only be used with Expr::AggregateFunction" ); }; - udaf.order_by = self.order_by; - udaf.filter = self.filter.map(Box::new); - udaf.distinct = self.distinct; - udaf.null_treatment = self.null_treatment; + if let Some(order_by) = &order_by { + for expr in order_by.iter() { + if !matches!(expr, Expr::Sort(_)) { + return plan_err!( + "ORDER BY expressions must be Expr::Sort, found {expr:?}" + ); + } + } + } + + udaf.order_by = order_by; + udaf.filter = filter.map(Box::new); + udaf.distinct = distinct; + udaf.null_treatment = null_treatment; Ok(Expr::AggregateFunction(udaf)) } /// Add `ORDER BY ` + /// + /// Note: `order_by` must be [`Expr::Sort`] pub fn order_by(mut self, order_by: Vec) -> AggregateBuilder { self.order_by = Some(order_by); self From d547853e30b3477c305f967a6120e7da3c3c4264 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 8 Jun 2024 11:02:46 -0400 Subject: [PATCH 09/11] Improve documentation more --- datafusion-examples/examples/expr_api.rs | 10 +++++----- datafusion/expr/src/expr.rs | 15 ++++++++++++++- datafusion/expr/src/udaf.rs | 2 +- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index d6164458e7bc..b419af2847f7 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -47,10 +47,10 @@ use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator}; /// The code in this example shows how to: /// 1. Create [`Expr`]s using different APIs: [`main`]` /// 2. Use the fluent API to easly create complex [`Expr`]s: [`expr_fn_demo`] -/// 2. Evaluate [`Expr`]s against data: [`evaluate_demo`] -/// 3. Simplify expressions: [`simplify_demo`] -/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`] -/// 5. Get the types of the expressions: [`expression_type_demo`] +/// 3. Evaluate [`Expr`]s against data: [`evaluate_demo`] +/// 4. Simplify expressions: [`simplify_demo`] +/// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`] +/// 6. Get the types of the expressions: [`expression_type_demo`] #[tokio::main] async fn main() -> Result<()> { // The easiest way to do create expressions is to use the @@ -65,7 +65,7 @@ async fn main() -> Result<()> { )); assert_eq!(expr, expr2); - // See how to build aggregate functions with a fluent API + // See how to build aggregate functions with the expr_fn API expr_fn_demo()?; // See how to evaluate expressions diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 1abd8c97ee10..10e2edd17b21 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -255,19 +255,23 @@ pub enum Expr { /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. /// + /// ```text /// CASE WHEN condition THEN result /// [WHEN ...] /// [ELSE result] /// END + /// ``` /// /// The second form uses a base expression and then a series of "when" clauses that match on a /// literal value. /// + /// ```text /// CASE expression /// WHEN value THEN result /// [WHEN ...] /// [ELSE result] /// END + /// ``` Case(Case), /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. @@ -279,7 +283,12 @@ pub enum Expr { Sort(Sort), /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), - /// Represents the call of an aggregate built-in function with arguments. + /// Calls an aggregate function with arguments, and optional + /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. + /// + /// See also [`AggregateExt`] to set these fields. + /// + /// [`AggregateExt`]: crate::udaf::AggregateExt AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), @@ -623,6 +632,10 @@ impl AggregateFunctionDefinition { } /// Aggregate function +/// +/// See also [`AggregateExt`] to set these fields on `Expr` +/// +/// [`AggregateExt`]: crate::udaf::AggregateExt #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 3d4b6780785d..a248518c2d94 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -651,7 +651,7 @@ pub trait AggregateExt { /// Implementation of [`AggregateExt`]. /// -/// See [`AggregateExec`] for usage and examples +/// See [`AggregateExt`] for usage and examples #[derive(Debug, Clone)] pub struct AggregateBuilder { udaf: Option, From 980c30c20b345646688b3e7221b3d5dbb498b858 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 8 Jun 2024 14:51:25 -0400 Subject: [PATCH 10/11] fixup --- datafusion-examples/examples/expr_api.rs | 4 ++-- datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index b419af2847f7..e667de3ddcb5 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -99,12 +99,12 @@ fn expr_fn_demo() -> Result<()> { // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts ) let agg = first_value .call(vec![col("price")]) - .order_by(vec![col("ts")]) + .order_by(vec![col("ts").sort(false, false)]) .filter(col("quantity").gt(lit(100))) .build()?; // build the aggregate assert_eq!( agg.to_string(), - "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts]" + "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts DESC NULLS LAST ]" ); Ok(()) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e4c84efb720a..4f35b82e4908 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -26,6 +26,8 @@ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; +use prost::Message; + use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; @@ -64,8 +66,6 @@ use datafusion_proto::logical_plan::{ }; use datafusion_proto::protobuf; -use prost::Message; - #[cfg(feature = "json")] fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { let string = serde_json::to_string(proto).unwrap(); @@ -648,7 +648,7 @@ async fn roundtrip_expr_api() -> Result<()> { ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), first_value(lit(1), None), - first_value(lit(1), Some(vec![lit(2)])), + first_value(lit(1), Some(vec![lit(2).sort(true, true)])), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), sum(lit(1)), From a851efce91fc2b3084e4bf826282d2b147b1eb20 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 9 Jun 2024 11:26:50 +0800 Subject: [PATCH 11/11] rm spce Signed-off-by: jayzhan211 --- datafusion-examples/examples/expr_api.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index e667de3ddcb5..591f6ac3de95 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -104,7 +104,7 @@ fn expr_fn_demo() -> Result<()> { .build()?; // build the aggregate assert_eq!( agg.to_string(), - "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts DESC NULLS LAST ]" + "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts DESC NULLS LAST]" ); Ok(())