From 659b28fc72be9357aef61c5e2ccc5a27b98fc9eb Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 14 May 2024 09:18:53 +0800 Subject: [PATCH 1/2] align udaf and builtin Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 28 ++++--------------- datafusion/expr/src/expr.rs | 12 ++------ datafusion/sql/src/expr/function.rs | 9 ++++-- .../substrait/src/logical_plan/producer.rs | 5 +++- 4 files changed, 18 insertions(+), 36 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d4a9a949fc41..f079e964da28 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -255,28 +255,12 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { filter, order_by, null_treatment: _, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn(..) => create_function_physical_name( - func_def.name(), - *distinct, - args, - order_by.as_ref(), - ), - AggregateFunctionDefinition::UDF(fun) => { - // TODO: Add support for filter by in AggregateUDF - if filter.is_some() { - return exec_err!( - "aggregate expression with filter is not supported" - ); - } - - let names = args - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>>()?; - Ok(format!("{}({})", fun.name(), names.join(","))) - } - }, + }) => create_function_physical_name( + func_def.name(), + *distinct, + args, + order_by.as_ref(), + ), Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( "ROLLUP ({})", diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 660a45c27a29..4e0c5bb5291a 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1891,16 +1891,8 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { order_by, null_treatment, }) => { - match func_def { - AggregateFunctionDefinition::BuiltIn(..) => { - write_function_name(w, func_def.name(), *distinct, args)?; - } - AggregateFunctionDefinition::UDF(fun) => { - write!(w, "{}(", fun.name())?; - write_names_join(w, args, ",")?; - write!(w, ")")?; - } - }; + write_function_name(w, func_def.name(), *distinct, args)?; + if let Some(fe) = filter { write!(w, " FILTER (WHERE {fe})")?; }; diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 3adf2960784d..dc0ddd4714dd 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -229,12 +229,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; - // TODO: Support filter and distinct for UDAFs + let filter: Option> = filter + .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) + .transpose()? + .map(Box::new); return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( fm, args, - false, - None, + distinct, + filter, order_by, null_treatment, ))); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index db5d341bc225..6f0738c38df5 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -722,7 +722,10 @@ pub fn to_substrait_agg_measure( arguments, sorts, output_type: None, - invocation: AggregationInvocation::All as i32, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, phase: AggregationPhase::Unspecified as i32, args: vec![], options: vec![], From 77daaafdc61926048c376eb60e0ea09adc17f987 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 14 May 2024 09:34:47 +0800 Subject: [PATCH 2/2] add more Signed-off-by: jayzhan211 --- .../aggregate_statistics.rs | 76 +++++++++++-------- datafusion/core/src/physical_planner.rs | 3 +- .../src/analyzer/count_wildcard_rule.rs | 44 +++++++---- .../src/single_distinct_to_groupby.rs | 25 ++++++ .../physical-expr-common/src/aggregate/mod.rs | 9 +++ datafusion/physical-plan/src/windows/mod.rs | 1 + datafusion/proto/src/physical_plan/mod.rs | 2 +- .../tests/cases/roundtrip_physical_plan.rs | 1 + 8 files changed, 113 insertions(+), 48 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 505748860388..1a82dac4658c 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -30,6 +30,7 @@ use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion_physical_plan::udaf::AggregateFunctionExpr; /// Optimizer that uses available statistics for aggregate functions #[derive(Default)] @@ -57,13 +58,9 @@ impl PhysicalOptimizerRule for AggregateStatistics { let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { if let Some((non_null_rows, name)) = - take_optimizable_column_count(&**expr, &stats) + take_optimizable_column_and_table_count(&**expr, &stats) { projections.push((expressions::lit(non_null_rows), name.to_owned())); - } else if let Some((num_rows, name)) = - take_optimizable_table_count(&**expr, &stats) - { - projections.push((expressions::lit(num_rows), name.to_owned())); } else if let Some((min, name)) = take_optimizable_min(&**expr, &stats) { projections.push((expressions::lit(min), name.to_owned())); } else if let Some((max, name)) = take_optimizable_max(&**expr, &stats) { @@ -137,43 +134,48 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> None } -/// If this agg_expr is a count that is exactly defined in the statistics, return it. -fn take_optimizable_table_count( +/// If this agg_expr is a count that can be exactly derived from the statistics, return it. +fn take_optimizable_column_and_table_count( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, String)> { - if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( - &stats.num_rows, - agg_expr.as_any().downcast_ref::(), - ) { - // TODO implementing Eq on PhysicalExpr would help a lot here - if casted_expr.expressions().len() == 1 { - if let Some(lit_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - casted_expr.name().to_owned(), - )); + let col_stats = &stats.column_statistics; + if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { + if agg_expr.fun().name() == "COUNT" && !agg_expr.is_distinct() { + if let Precision::Exact(num_rows) = stats.num_rows { + let exprs = agg_expr.expressions(); + if exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = + exprs[0].as_any().downcast_ref::() + { + let current_val = &col_stats[col_expr.index()].null_count; + if let &Precision::Exact(val) = current_val { + return Some(( + ScalarValue::Int64(Some((num_rows - val) as i64)), + agg_expr.name().to_string(), + )); + } + } else if let Some(lit_expr) = + exprs[0].as_any().downcast_ref::() + { + if lit_expr.value() == &COUNT_STAR_EXPANSION { + return Some(( + ScalarValue::Int64(Some(num_rows as i64)), + agg_expr.name().to_string(), + )); + } + } } } } } - None -} - -/// If this agg_expr is a count that can be exactly derived from the statistics, return it. -fn take_optimizable_column_count( - agg_expr: &dyn AggregateExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - let col_stats = &stats.column_statistics; - if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( + // TODO: Remove this after revmoing Builtin Count + else if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( &stats.num_rows, agg_expr.as_any().downcast_ref::(), ) { + // TODO implementing Eq on PhysicalExpr would help a lot here if casted_expr.expressions().len() == 1 { // TODO optimize with exprs other than Column if let Some(col_expr) = casted_expr.expressions()[0] @@ -187,6 +189,16 @@ fn take_optimizable_column_count( casted_expr.name().to_string(), )); } + } else if let Some(lit_expr) = casted_expr.expressions()[0] + .as_any() + .downcast_ref::() + { + if lit_expr.value() == &COUNT_STAR_EXPANSION { + return Some(( + ScalarValue::Int64(Some(num_rows as i64)), + casted_expr.name().to_owned(), + )); + } } } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index f079e964da28..406196a59146 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -252,7 +252,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { func_def, distinct, args, - filter, + filter: _, order_by, null_treatment: _, }) => create_function_physical_name( @@ -1925,6 +1925,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( physical_input_schema, name, ignore_nulls, + *distinct, )?; (agg_expr, filter, physical_sort_exprs) } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index a607d49ef967..dfbd5f5632ee 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -25,7 +25,9 @@ use datafusion_expr::expr::{ AggregateFunction, AggregateFunctionDefinition, WindowFunction, }; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; +use datafusion_expr::{ + aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition, +}; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -54,23 +56,37 @@ fn is_wildcard(expr: &Expr) -> bool { } fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { - matches!( - &aggregate_function.func_def, - AggregateFunctionDefinition::BuiltIn( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ) - ) && aggregate_function.args.len() == 1 - && is_wildcard(&aggregate_function.args[0]) + match aggregate_function { + AggregateFunction { + func_def: AggregateFunctionDefinition::UDF(udf), + args, + .. + } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true, + AggregateFunction { + func_def: + AggregateFunctionDefinition::BuiltIn( + datafusion_expr::aggregate_function::AggregateFunction::Count, + ), + args, + .. + } if args.len() == 1 && is_wildcard(&args[0]) => true, + _ => false, + } } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { - matches!( - &window_function.fun, + let args = &window_function.args; + match window_function.fun { WindowFunctionDefinition::AggregateFunction( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ) - ) && window_function.args.len() == 1 - && is_wildcard(&window_function.args[0]) + aggregate_function::AggregateFunction::Count, + ) if args.len() == 1 && is_wildcard(&args[0]) => true, + WindowFunctionDefinition::AggregateUDF(ref udaf) + if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => + { + true + } + _ => false, + } } fn analyze_internal(plan: LogicalPlan) -> Result> { diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index aaf4667fb000..0cad797f2cc6 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -90,6 +90,31 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { } else if !matches!(fun, Sum | Min | Max) { return Ok(false); } + } else if let Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::UDF(fun), + distinct, + args, + filter, + order_by, + null_treatment: _, + }) = expr + { + if filter.is_some() || order_by.is_some() { + return Ok(false); + } + aggregate_count += 1; + if *distinct { + for e in args { + fields_set.insert(e.canonical_name()); + } + } else if fun.name() != "SUM" + && fun.name() != "MIN" + && fun.name() != "MAX" + { + return Ok(false); + } + } else { + return Ok(false); } } Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1) diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index d2e3414fbfce..05641b373b72 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -34,6 +34,7 @@ use self::utils::{down_cast_any_ref, ordering_fields}; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. /// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. +#[allow(clippy::too_many_arguments)] pub fn create_aggregate_expr( fun: &AggregateUDF, input_phy_exprs: &[Arc], @@ -42,6 +43,7 @@ pub fn create_aggregate_expr( schema: &Schema, name: impl Into, ignore_nulls: bool, + is_distinct: bool, ) -> Result> { let input_exprs_types = input_phy_exprs .iter() @@ -71,6 +73,7 @@ pub fn create_aggregate_expr( ordering_req: ordering_req.to_vec(), ignore_nulls, ordering_fields, + is_distinct, })) } @@ -162,6 +165,7 @@ pub struct AggregateFunctionExpr { ordering_req: LexOrdering, ignore_nulls: bool, ordering_fields: Vec, + is_distinct: bool, } impl AggregateFunctionExpr { @@ -169,6 +173,11 @@ impl AggregateFunctionExpr { pub fn fun(&self) -> &AggregateUDF { &self.fun } + + /// Return if the aggregation is distinct + pub fn is_distinct(&self) -> bool { + self.is_distinct + } } impl AggregateExpr for AggregateFunctionExpr { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index ff60329ce179..d1223f78808c 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -103,6 +103,7 @@ pub fn create_window_expr( input_schema, name, ignore_nulls, + false, )?; window_expr_from_aggregate_expr( partition_by, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 1c5ba861d297..4de0b7c06d45 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -525,7 +525,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let sort_exprs = &[]; let ordering_req = &[]; let ignore_nulls = false; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index c2018352c7cf..30a28081edff 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -426,6 +426,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { &schema, "example_agg", false, + false, )?]; roundtrip_test_with_context(