Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support UDAF to align Builtin aggregate function #10493

Merged
merged 2 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 44 additions & 32 deletions datafusion/core/src/physical_optimizer/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use datafusion_common::stats::Precision;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
use datafusion_physical_plan::udaf::AggregateFunctionExpr;

/// Optimizer that uses available statistics for aggregate functions
#[derive(Default)]
Expand Down Expand Up @@ -57,13 +58,9 @@ impl PhysicalOptimizerRule for AggregateStatistics {
let mut projections = vec![];
for expr in partial_agg_exec.aggr_expr() {
if let Some((non_null_rows, name)) =
take_optimizable_column_count(&**expr, &stats)
take_optimizable_column_and_table_count(&**expr, &stats)
{
projections.push((expressions::lit(non_null_rows), name.to_owned()));
} else if let Some((num_rows, name)) =
take_optimizable_table_count(&**expr, &stats)
{
projections.push((expressions::lit(num_rows), name.to_owned()));
} else if let Some((min, name)) = take_optimizable_min(&**expr, &stats) {
projections.push((expressions::lit(min), name.to_owned()));
} else if let Some((max, name)) = take_optimizable_max(&**expr, &stats) {
Expand Down Expand Up @@ -137,43 +134,48 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>>
None
}

/// If this agg_expr is a count that is exactly defined in the statistics, return it.
fn take_optimizable_table_count(
/// If this agg_expr is a count that can be exactly derived from the statistics, return it.
fn take_optimizable_column_and_table_count(
agg_expr: &dyn AggregateExpr,
stats: &Statistics,
) -> Option<(ScalarValue, String)> {
if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
&stats.num_rows,
agg_expr.as_any().downcast_ref::<expressions::Count>(),
) {
// TODO implementing Eq on PhysicalExpr would help a lot here
if casted_expr.expressions().len() == 1 {
if let Some(lit_expr) = casted_expr.expressions()[0]
.as_any()
.downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some((
ScalarValue::Int64(Some(num_rows as i64)),
casted_expr.name().to_owned(),
));
let col_stats = &stats.column_statistics;
if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if agg_expr.fun().name() == "COUNT" && !agg_expr.is_distinct() {
if let Precision::Exact(num_rows) = stats.num_rows {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
{
let current_val = &col_stats[col_expr.index()].null_count;
if let &Precision::Exact(val) = current_val {
return Some((
ScalarValue::Int64(Some((num_rows - val) as i64)),
agg_expr.name().to_string(),
));
}
} else if let Some(lit_expr) =
exprs[0].as_any().downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some((
ScalarValue::Int64(Some(num_rows as i64)),
agg_expr.name().to_string(),
));
}
}
}
}
}
}
None
}

/// If this agg_expr is a count that can be exactly derived from the statistics, return it.
fn take_optimizable_column_count(
agg_expr: &dyn AggregateExpr,
stats: &Statistics,
) -> Option<(ScalarValue, String)> {
let col_stats = &stats.column_statistics;
if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
// TODO: Remove this after revmoing Builtin Count
else if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
&stats.num_rows,
agg_expr.as_any().downcast_ref::<expressions::Count>(),
) {
// TODO implementing Eq on PhysicalExpr would help a lot here
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://docs.rs/datafusion/latest/datafusion/physical_expr/trait.PhysicalExpr.html# seems to imply it is possible to to compare PhysicalExprs (perhaps expr.eq(other_expr.as_any()) 🤔 )

I don't know PhysicalExpr doesn't implement PartialEq directly

pub trait PhysicalExpr: ... PartialEq<dyn Any> {`

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had not take a look carefully to the comment here, since they are not needed after removing builtin function

if casted_expr.expressions().len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = casted_expr.expressions()[0]
Expand All @@ -187,6 +189,16 @@ fn take_optimizable_column_count(
casted_expr.name().to_string(),
));
}
} else if let Some(lit_expr) = casted_expr.expressions()[0]
.as_any()
.downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some((
ScalarValue::Int64(Some(num_rows as i64)),
casted_expr.name().to_owned(),
));
}
}
}
}
Expand Down
31 changes: 8 additions & 23 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,31 +252,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
func_def,
distinct,
args,
filter,
filter: _,
order_by,
null_treatment: _,
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(..) => create_function_physical_name(
func_def.name(),
*distinct,
args,
order_by.as_ref(),
),
AggregateFunctionDefinition::UDF(fun) => {
// TODO: Add support for filter by in AggregateUDF
if filter.is_some() {
return exec_err!(
"aggregate expression with filter is not supported"
);
}

let names = args
.iter()
.map(|e| create_physical_name(e, false))
.collect::<Result<Vec<_>>>()?;
Ok(format!("{}({})", fun.name(), names.join(",")))
}
},
}) => create_function_physical_name(
func_def.name(),
*distinct,
args,
order_by.as_ref(),
),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(format!(
"ROLLUP ({})",
Expand Down Expand Up @@ -1941,6 +1925,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
physical_input_schema,
name,
ignore_nulls,
*distinct,
)?;
(agg_expr, filter, physical_sort_exprs)
}
Expand Down
12 changes: 2 additions & 10 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1891,16 +1891,8 @@ fn write_name<W: Write>(w: &mut W, e: &Expr) -> Result<()> {
order_by,
null_treatment,
}) => {
match func_def {
AggregateFunctionDefinition::BuiltIn(..) => {
write_function_name(w, func_def.name(), *distinct, args)?;
}
AggregateFunctionDefinition::UDF(fun) => {
write!(w, "{}(", fun.name())?;
write_names_join(w, args, ",")?;
write!(w, ")")?;
}
};
write_function_name(w, func_def.name(), *distinct, args)?;

if let Some(fe) = filter {
write!(w, " FILTER (WHERE {fe})")?;
};
Expand Down
44 changes: 30 additions & 14 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use datafusion_expr::expr::{
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};
use datafusion_expr::{
aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition,
};

/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
///
Expand Down Expand Up @@ -54,23 +56,37 @@ fn is_wildcard(expr: &Expr) -> bool {
}

fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
matches!(
&aggregate_function.func_def,
AggregateFunctionDefinition::BuiltIn(
datafusion_expr::aggregate_function::AggregateFunction::Count,
)
) && aggregate_function.args.len() == 1
&& is_wildcard(&aggregate_function.args[0])
match aggregate_function {
AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(udf),
args,
..
} if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true,
AggregateFunction {
func_def:
AggregateFunctionDefinition::BuiltIn(
datafusion_expr::aggregate_function::AggregateFunction::Count,
),
args,
..
} if args.len() == 1 && is_wildcard(&args[0]) => true,
_ => false,
}
}

fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
matches!(
&window_function.fun,
let args = &window_function.args;
match window_function.fun {
WindowFunctionDefinition::AggregateFunction(
datafusion_expr::aggregate_function::AggregateFunction::Count,
)
) && window_function.args.len() == 1
&& is_wildcard(&window_function.args[0])
aggregate_function::AggregateFunction::Count,
) if args.len() == 1 && is_wildcard(&args[0]) => true,
WindowFunctionDefinition::AggregateUDF(ref udaf)
if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should document the COUNT somewhere. Also should it do a case insensitive comparison?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name should follow what is defined in Count struct. case sensitive is not needed, it is strictly comparing with fn name().

{
true
}
_ => false,
}
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
Expand Down
25 changes: 25 additions & 0 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,31 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
} else if !matches!(fun, Sum | Min | Max) {
return Ok(false);
}
} else if let Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(fun),
distinct,
args,
filter,
order_by,
null_treatment: _,
}) = expr
{
if filter.is_some() || order_by.is_some() {
return Ok(false);
}
aggregate_count += 1;
if *distinct {
for e in args {
fields_set.insert(e.canonical_name());
}
} else if fun.name() != "SUM"
&& fun.name() != "MIN"
&& fun.name() != "MAX"
{
return Ok(false);
}
} else {
return Ok(false);
}
}
Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
Expand Down
9 changes: 9 additions & 0 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use self::utils::{down_cast_any_ref, ordering_fields};

/// Creates a physical expression of the UDAF, that includes all necessary type coercion.
/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF.
#[allow(clippy::too_many_arguments)]
pub fn create_aggregate_expr(
fun: &AggregateUDF,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
Expand All @@ -42,6 +43,7 @@ pub fn create_aggregate_expr(
schema: &Schema,
name: impl Into<String>,
ignore_nulls: bool,
is_distinct: bool,
) -> Result<Arc<dyn AggregateExpr>> {
let input_exprs_types = input_phy_exprs
.iter()
Expand Down Expand Up @@ -71,6 +73,7 @@ pub fn create_aggregate_expr(
ordering_req: ordering_req.to_vec(),
ignore_nulls,
ordering_fields,
is_distinct,
}))
}

Expand Down Expand Up @@ -162,13 +165,19 @@ pub struct AggregateFunctionExpr {
ordering_req: LexOrdering,
ignore_nulls: bool,
ordering_fields: Vec<Field>,
is_distinct: bool,
}

impl AggregateFunctionExpr {
/// Return the `AggregateUDF` used by this `AggregateFunctionExpr`
pub fn fun(&self) -> &AggregateUDF {
&self.fun
}

/// Return if the aggregation is distinct
pub fn is_distinct(&self) -> bool {
self.is_distinct
}
}

impl AggregateExpr for AggregateFunctionExpr {
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ pub fn create_window_expr(
input_schema,
name,
ignore_nulls,
false,
)?;
window_expr_from_aggregate_expr(
partition_by,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
let sort_exprs = &[];
let ordering_req = &[];
let ignore_nulls = false;
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls)
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false)
}
}
}).transpose()?.ok_or_else(|| {
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
&schema,
"example_agg",
false,
false,
)?];

roundtrip_test_with_context(
Expand Down
9 changes: 6 additions & 3 deletions datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
)?;
let order_by = (!order_by.is_empty()).then_some(order_by);
let args = self.function_args_to_expr(args, schema, planner_context)?;
// TODO: Support filter and distinct for UDAFs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

let filter: Option<Box<Expr>> = filter
.map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context))
.transpose()?
.map(Box::new);
return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
fm,
args,
false,
None,
distinct,
filter,
order_by,
null_treatment,
)));
Expand Down
5 changes: 4 additions & 1 deletion datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,10 @@ pub fn to_substrait_agg_measure(
arguments,
sorts,
output_type: None,
invocation: AggregationInvocation::All as i32,
invocation: match distinct {
true => AggregationInvocation::Distinct as i32,
false => AggregationInvocation::All as i32,
},
phase: AggregationPhase::Unspecified as i32,
args: vec![],
options: vec![],
Expand Down