Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Sum UDAF #10651

Merged
merged 50 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
192211f
move accumulate
jayzhan211 May 5, 2024
403ee1d
move prim_op
jayzhan211 May 5, 2024
5bcab35
move test to slt
jayzhan211 May 5, 2024
426f2ab
remove sum distinct
jayzhan211 May 5, 2024
3712e18
Merge branch 'move-sum-test' into sum-udaf
jayzhan211 May 5, 2024
e080d37
move sum aggregate
jayzhan211 May 5, 2024
2b6784c
Merge remote-tracking branch 'upstream/main' into sum-udaf
jayzhan211 May 6, 2024
64db75c
fix args
jayzhan211 May 8, 2024
c086a22
Merge remote-tracking branch 'upstream/main' into sum-udaf
jayzhan211 May 8, 2024
a9a5423
add sum
jayzhan211 May 8, 2024
201c90c
Merge remote-tracking branch 'upstream/main' into sum-udaf
jayzhan211 May 19, 2024
3e9e7e9
merge fix
jayzhan211 May 19, 2024
22540b2
fix sum sig
jayzhan211 May 19, 2024
d9229db
todo: wait ahash merge
jayzhan211 May 19, 2024
518894a
Merge remote-tracking branch 'upstream/main' into sum-udaf
jayzhan211 May 25, 2024
ee068db
rebase
jayzhan211 May 25, 2024
6224333
disable ordering req by default
jayzhan211 May 25, 2024
47ae11f
check arg count
jayzhan211 May 25, 2024
25dcb64
rm old workflow
jayzhan211 May 25, 2024
d16f1b1
fmt
jayzhan211 May 25, 2024
5381f2d
fix failed test
jayzhan211 May 25, 2024
b403331
doc and fmt
jayzhan211 May 25, 2024
78a70b3
check udaf first
jayzhan211 May 25, 2024
79019fe
fmt
jayzhan211 May 25, 2024
20e9f79
fix ci
jayzhan211 May 25, 2024
4f2f0ac
fix ci
jayzhan211 May 25, 2024
ca4b528
fix ci
jayzhan211 May 25, 2024
e6b021e
fix err msg AGAIN
jayzhan211 May 25, 2024
81dd68f
rm sum in builtin test which covered in sql
jayzhan211 May 25, 2024
ffb0a98
proto for window with udaf
jayzhan211 May 25, 2024
dafd1aa
fix slt
jayzhan211 May 25, 2024
70b8651
Merge remote-tracking branch 'upstream/main' into sum-udaf
jayzhan211 May 26, 2024
f6d37bf
fmt
jayzhan211 May 26, 2024
921dc00
fix err msg
jayzhan211 May 26, 2024
093fb24
Merge remote-tracking branch 'upstream/main' into sum-udaf
jayzhan211 May 27, 2024
f684f5d
fix exprfn
jayzhan211 May 27, 2024
c1e74f7
fix ciy
jayzhan211 May 27, 2024
6d3ef58
fix ci
jayzhan211 May 27, 2024
53c7bb1
Merge remote-tracking branch 'upstream/main' into sum-udaf
jayzhan211 May 30, 2024
02fd8a5
rename first/last to lowercase
jayzhan211 May 30, 2024
6c7ce04
skip sum
jayzhan211 May 30, 2024
5490bcf
fix firstvalue
jayzhan211 May 30, 2024
5f93beb
Merge remote-tracking branch 'upstream/main' into sum-udaf
jayzhan211 May 30, 2024
ff947bb
clippy
jayzhan211 May 30, 2024
2492ba7
add doc
jayzhan211 May 31, 2024
73573be
rm has_ordering_req
jayzhan211 May 31, 2024
f2b3732
default hard req
jayzhan211 May 31, 2024
2c0c52c
insensitive for sum
jayzhan211 May 31, 2024
62346dd
cleanup duplicate code
jayzhan211 Jun 3, 2024
a41fcc5
Re-introduce check
mustafasrepo Jun 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ impl AggregateUDFImpl for GeoMeanUdaf {
true
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(GeometricMeanGroupsAccumulator::new()))
}
}
Expand Down
6 changes: 5 additions & 1 deletion datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ impl AggregateUDFImpl for BetterAvgUdaf {
true
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
unimplemented!("should not get here");
}

// we override method, to return new expression which would substitute
// user defined function call
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ use datafusion_expr::{
avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null, sum};
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::median;
use datafusion_functions_aggregate::expr_fn::sum;

use async_trait::async_trait;

Expand Down Expand Up @@ -1593,9 +1594,8 @@ mod tests {
use datafusion_common::{Constraint, Constraints};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::{
array_agg, cast, count_distinct, create_udf, expr, lit, sum,
BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame,
WindowFunctionDefinition,
array_agg, cast, count_distinct, create_udf, expr, lit, BuiltInWindowFunction,
ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition,
};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ mod tests {
assert_eq!(
evaluate_partition_prefix(
partitions,
&[col("a").eq(lit("foo")).and((col("b").eq(lit("bar"))))],
&[col("a").eq(lit("foo")).and(col("b").eq(lit("bar")))],
),
Some(Path::from("a=foo/b=bar")),
);
Expand Down
5 changes: 2 additions & 3 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2257,9 +2257,8 @@ mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, TableReference};
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
use datafusion_expr::{
col, lit, sum, LogicalPlanBuilder, UserDefinedLogicalNodeCore,
};
use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore};
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_physical_expr::EquivalenceProperties;

fn make_session_state() -> SessionState {
Expand Down
1 change: 0 additions & 1 deletion datafusion/core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ pub use datafusion_expr::{
Expr,
};
pub use datafusion_functions::expr_fn::*;
pub use datafusion_functions_aggregate::expr_fn::*;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expr / sum is not removed in this PR, so I need to remove this to avoid import conflict

#[cfg(feature = "array_expressions")]
pub use datafusion_functions_array::expr_fn::*;

Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col,
placeholder, scalar_subquery, sum, when, wildcard, AggregateFunction, Expr,
ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
placeholder, scalar_subquery, when, wildcard, AggregateFunction, Expr, ExprSchemable,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::sum;

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
12 changes: 11 additions & 1 deletion datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::{Result, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::type_coercion::aggregates::coerce_types;
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
use datafusion_expr::{
AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use test_utils::add_empty_batches;
Expand Down Expand Up @@ -341,7 +343,7 @@ fn get_random_function(
window_fn_map.insert(
"sum",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
WindowFunctionDefinition::AggregateUDF(sum_udaf()),
vec![arg.clone()],
),
);
Expand Down Expand Up @@ -468,6 +470,14 @@ fn get_random_function(
let coerced = coerce_types(f, &[dt], &sig).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
}
} else if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn {
if !args.is_empty() {
// Do type coercion first argument
let a = args[0].clone();
let dt = a.data_type(schema.as_ref()).unwrap();
let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
}
}

(window_fn.clone(), args, fn_name.to_string())
Expand Down
7 changes: 5 additions & 2 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async fn test_udaf_as_window_with_frame_without_retract_batch() {
let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t";
// Note if this query ever does start working
let err = execute(&ctx, sql).await.unwrap_err();
assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { inner: AggregateUDF { name: \"time_sum\", signature: Signature { type_signature: Exact([Timestamp(Nanosecond, None)]), volatility: Immutable }, fun: \"<FUNC>\" } }(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING");
assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is certainly nicer

}

/// Basic query for with a udaf returning a structure
Expand Down Expand Up @@ -729,7 +729,10 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
true
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(self.clone()))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,8 @@ async fn udaf_as_window_func() -> Result<()> {
context.register_udaf(my_acc);

let sql = "SELECT a, MY_ACC(b) OVER(PARTITION BY a) FROM my_table";
let expected = r#"Projection: my_table.a, AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
WindowAggr: windowExpr=[[AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
let expected = r#"Projection: my_table.a, my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
WindowAggr: windowExpr=[[my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
TableScan: my_table"#;

let dataframe = context.sql(sql).await.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/built_in_window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ impl BuiltInWindowFunction {
Ntile => "NTILE",
Lag => "LAG",
Lead => "LEAD",
FirstValue => "FIRST_VALUE",
LastValue => "LAST_VALUE",
FirstValue => "first_value",
LastValue => "last_value",
NthValue => "NTH_VALUE",
}
}
Expand Down
18 changes: 13 additions & 5 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -754,10 +754,14 @@ impl WindowFunctionDefinition {
impl fmt::Display for WindowFunctionDefinition {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f),
WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f),
WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f),
WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f),
WindowFunctionDefinition::AggregateFunction(fun) => {
std::fmt::Display::fmt(fun, f)
}
WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
std::fmt::Display::fmt(fun, f)
}
WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Display::fmt(fun, f),
WindowFunctionDefinition::WindowUDF(fun) => std::fmt::Display::fmt(fun, f),
}
}
}
Expand Down Expand Up @@ -2263,7 +2267,11 @@ mod test {
let fun = find_df_window_func(name).unwrap();
let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap();
assert_eq!(fun, fun2);
assert_eq!(fun.to_string(), name.to_uppercase());
if fun.to_string() == "first_value" || fun.to_string() == "last_value" {
assert_eq!(fun.to_string(), name);
} else {
assert_eq!(fun.to_string(), name.to_uppercase());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder, maybe we should treat udf names case insensitive way. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#10695 track the issue to rename name to lowercase

}
}
Ok(())
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ pub fn max(expr: Expr) -> Expr {
}

/// Create an expression to represent the sum() aggregate function
///
/// TODO: Remove this function and use `sum` from `datafusion_functions_aggregate::expr_fn` instead
pub fn sum(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Sum,
Expand Down
39 changes: 35 additions & 4 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ use crate::expr::{
InSubquery, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
};
use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::data_types_with_scalar_udf;
use crate::{utils, LogicalPlan, Projection, Subquery};
use crate::type_coercion::functions::{
data_types_with_aggregate_udf, data_types_with_scalar_udf,
};
use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
Expand Down Expand Up @@ -158,7 +160,25 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
fun.return_type(&data_types)
match fun {
WindowFunctionDefinition::AggregateUDF(udf) => {
let new_types = data_types_with_aggregate_udf(&data_types, udf).map_err(|err| {
plan_datafusion_err!(
"{} and {}",
err,
utils::generate_signature_error_msg(
fun.name(),
fun.signature().clone(),
&data_types
)
)
})?;
Ok(fun.return_type(&new_types)?)
}
Comment on lines +164 to +177
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should bury this check, and conversion inside to the fun.return_type implementation for WindowFunctionDefinition::AggregateUDF not sure though.

Copy link
Contributor Author

@jayzhan211 jayzhan211 Jun 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer separate coerce_types and return_types given the difference between these two

_ => {
fun.return_type(&data_types)
}
}
}
Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => {
let data_types = args
Expand All @@ -170,7 +190,18 @@ impl ExprSchemable for Expr {
fun.return_type(&data_types)
}
AggregateFunctionDefinition::UDF(fun) => {
Ok(fun.return_type(&data_types)?)
let new_types = data_types_with_aggregate_udf(&data_types, fun).map_err(|err| {
plan_datafusion_err!(
"{} and {}",
err,
utils::generate_signature_error_msg(
fun.name(),
fun.signature().clone(),
&data_types
)
)
})?;
Comment on lines +193 to +203
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment above applies here

Ok(fun.return_type(&new_types)?)
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ pub struct AccumulatorArgs<'a> {
/// If no `ORDER BY` is specified, `sort_exprs`` will be empty.
pub sort_exprs: &'a [Expr],

/// The name of the aggregate expression
pub name: &'a str,

/// Whether the aggregate function is distinct.
///
/// ```sql
Expand All @@ -82,9 +85,6 @@ pub struct AccumulatorArgs<'a> {

/// The number of arguments the aggregate function takes.
pub args_num: usize,

/// The name of the expression
pub name: &'a str,
}

/// [`StateFieldsArgs`] contains information about the fields that an
Expand Down
Loading