Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce expr builder for aggregate function #10560

Merged
merged 12 commits into from
Jun 9, 2024
45 changes: 45 additions & 0 deletions datafusion-examples/examples/udaf_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion::{
execution::{
config::SessionConfig,
context::{SessionContext, SessionState},
},
functions_aggregate::{expr_fn::first_value, register_all},
};

use datafusion_common::Result;
use datafusion_expr::{col, AggregateExt};

#[tokio::main]
async fn main() -> Result<()> {
alamb marked this conversation as resolved.
Show resolved Hide resolved
let ctx = SessionContext::new();
let config = SessionConfig::new();
let mut state = SessionState::new_with_config_rt(config, ctx.runtime_env());
let _ = register_all(&mut state);

let first_value_udaf = state.aggregate_functions().get("first_value").unwrap();
let first_value_builder = first_value_udaf
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
.call(vec![col("a")])
.order_by(vec![col("b")])
.build()?;

let first_value_fn = first_value(col("a"), Some(vec![col("b")]));
assert_eq!(first_value_builder, first_value_fn);
Ok(())
}
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub use signature::{
ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD,
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
105 changes: 102 additions & 3 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::expr::AggregateFunction;
use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
Expand All @@ -26,7 +27,8 @@ use crate::utils::AggregateOrderSensitivity;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, not_impl_err, Result};
use datafusion_common::{exec_err, not_impl_err, plan_err, Result};
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
Expand Down Expand Up @@ -139,8 +141,7 @@ impl AggregateUDF {
/// This utility allows using the UDAF without requiring access to
/// the registry, such as with the DataFrame API.
pub fn call(&self, args: Vec<Expr>) -> Expr {
// TODO: Support dictinct, filter, order by and null_treatment
Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf(
Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(self.clone()),
args,
false,
Expand Down Expand Up @@ -606,3 +607,101 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
(self.accumulator)(acc_args)
}
}

pub trait AggregateExt {
fn order_by(self, order_by: Vec<Expr>) -> AggregateBuilder;
fn filter(self, filter: Box<Expr>) -> AggregateBuilder;
fn distinct(self) -> AggregateBuilder;
fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder;
}

pub struct AggregateBuilder {
udaf: Option<AggregateFunction>,
order_by: Option<Vec<Expr>>,
filter: Option<Box<Expr>>,
distinct: bool,
null_treatment: Option<NullTreatment>,
}

impl AggregateBuilder {
fn new(udaf: Option<AggregateFunction>) -> Self {
Self {
udaf,
order_by: None,
filter: None,
distinct: false,
null_treatment: None,
}
}

pub fn build(self) -> Result<Expr> {
if let Some(mut udaf) = self.udaf {
udaf.order_by = self.order_by;
udaf.filter = self.filter;
udaf.distinct = self.distinct;
udaf.null_treatment = self.null_treatment;
return Ok(Expr::AggregateFunction(udaf));
}

plan_err!("Expect Expr::AggregateFunction")
}

pub fn order_by(mut self, order_by: Vec<Expr>) -> AggregateBuilder {
self.order_by = Some(order_by);
self
}

pub fn filter(mut self, filter: Box<Expr>) -> AggregateBuilder {
self.filter = Some(filter);
self
}

pub fn distinct(mut self) -> AggregateBuilder {
self.distinct = true;
self
}

pub fn null_treatment(mut self, null_treatment: NullTreatment) -> AggregateBuilder {
self.null_treatment = Some(null_treatment);
self
}
}

impl AggregateExt for Expr {
fn order_by(self, order_by: Vec<Expr>) -> AggregateBuilder {
match self {
Expr::AggregateFunction(mut udaf) => {
udaf.order_by = Some(order_by);
AggregateBuilder::new(Some(udaf))
}
_ => AggregateBuilder::new(None),
}
}
fn filter(self, filter: Box<Expr>) -> AggregateBuilder {
match self {
Expr::AggregateFunction(mut udaf) => {
udaf.filter = Some(filter);
AggregateBuilder::new(Some(udaf))
}
_ => AggregateBuilder::new(None),
}
}
fn distinct(self) -> AggregateBuilder {
match self {
Expr::AggregateFunction(mut udaf) => {
udaf.distinct = true;
AggregateBuilder::new(Some(udaf))
}
_ => AggregateBuilder::new(None),
}
}
fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder {
match self {
Expr::AggregateFunction(mut udaf) => {
udaf.null_treatment = Some(null_treatment);
AggregateBuilder::new(Some(udaf))
}
_ => AggregateBuilder::new(None),
}
}
}
24 changes: 16 additions & 8 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,28 @@ use datafusion_common::{
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Signature, TypeSignature,
Volatility,
Accumulator, AggregateExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature,
TypeSignature, Volatility,
};
use datafusion_physical_expr_common::aggregate::utils::get_sort_options;
use datafusion_physical_expr_common::sort_expr::{
limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr,
};

make_udaf_expr_and_func!(
FirstValue,
first_value,
"Returns the first value in a group of values.",
first_value_udaf
);
create_func!(FirstValue, first_value_udaf);

/// Returns the first value in a group of values.
pub fn first_value(expression: Expr, order_by: Option<Vec<Expr>>) -> Expr {
if let Some(order_by) = order_by {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is pretty cool indeed

first_value_udaf()
.call(vec![expression])
.order_by(order_by)
.build()
.unwrap()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is fine to unwrap since udaf.call() is guaranteed to be Expr::AggregateFunction

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100% agree

Suggested change
.unwrap()
// guaranteed to be `Expr::AggregateFunction`
.unwrap()

} else {
first_value_udaf().call(vec![expression])
}
}

pub struct FirstValue {
signature: Signature,
Expand Down
32 changes: 6 additions & 26 deletions datafusion/functions-aggregate/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,45 +48,25 @@ macro_rules! make_udaf_expr_and_func {
None,
))
}
create_func!($UDAF, $AGGREGATE_UDF_FN);
};
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $distinct:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to introduce another macro for distinct, so we have count_distinct, count_distinct_builder expr::fn.

I change to order_by because first_value needs it. Also, I change the expression for first_value to single expression, since it does not expect variadic args.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of count_distinct_builder what do you think about adding a distinct type method instead?

Like

let agg = count_builder()
  .args(col("a"))
  .distinct()
  .build()?

🤔

// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN(
$($arg: datafusion_expr::Expr,)*
distinct: bool,
) -> datafusion_expr::Expr {
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
vec![$($arg),*],
distinct,
None,
None,
None
))
}

create_func!($UDAF, $AGGREGATE_UDF_FN);
};
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN(
args: Vec<datafusion_expr::Expr>,
distinct: bool,
filter: Option<Box<datafusion_expr::Expr>>,
order_by: Option<Vec<datafusion_expr::Expr>>,
null_treatment: Option<sqlparser::ast::NullTreatment>
) -> datafusion_expr::Expr {
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
args,
distinct,
filter,
order_by,
null_treatment,
false,
None,
None,
None,
))
}

create_func!($UDAF, $AGGREGATE_UDF_FN);
};
}
Expand Down
22 changes: 11 additions & 11 deletions datafusion/optimizer/src/replace_distinct_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule};

use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Column, Result};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::expr_rewriter::normalize_cols;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{col, LogicalPlanBuilder};
use datafusion_expr::{col, AggregateExt, LogicalPlanBuilder};
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};

/// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]]
Expand Down Expand Up @@ -95,17 +94,18 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
let expr_cnt = on_expr.len();

// Construct the aggregation expression to be used to fetch the selected expressions.
let first_value_udaf =
let first_value_udaf: std::sync::Arc<datafusion_expr::AggregateUDF> =
config.function_registry().unwrap().udaf("first_value")?;
let aggr_expr = select_expr.into_iter().map(|e| {
Expr::AggregateFunction(AggregateFunction::new_udf(
first_value_udaf.clone(),
vec![e],
false,
None,
sort_expr.clone(),
None,
))
if let Some(order_by) = &sort_expr {
first_value_udaf
.call(vec![e])
.order_by(order_by.clone())
.build()
.unwrap()
} else {
first_value_udaf.call(vec![e])
}
});

let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
Expand Down
3 changes: 2 additions & 1 deletion datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ async fn roundtrip_expr_api() -> Result<()> {
lit(1),
),
array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)),
first_value(vec![lit(1)], false, None, None, None),
first_value(lit(1), None),
first_value(lit(1), Some(vec![lit(2)])),
covar_samp(lit(1.5), lit(2.2)),
covar_pop(lit(1.5), lit(2.2)),
sum(lit(1)),
Expand Down
10 changes: 10 additions & 0 deletions docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@ select log(-1), log(0), sqrt(-1);
| rollup(exprs) | Creates a grouping set for rollup sets. |
| sum(expr) | Сalculates the sum of `expr`. |

## Aggregate Function Builder

Import trait `AggregateUDFExprBuilder` and update the arguments directly in `Expr`

See datafusion-examples/examples/udaf_expr.rs for example usage.

| Syntax | Equivalent to |
| ----------------------------------------------------------------------- | ----------------------------------- |
| first_value_udaf.call(vec![expr]).order_by(vec![expr]).build().unwrap() | first_value(expr, Some(vec![expr])) |

## Subquery Expressions

| Syntax | Description |
Expand Down
Loading