Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FirstValue an UDAF, Change AggregateUDFImpl::accumulator signature, support ORDER BY for UDAFs #9874

Merged
merged 49 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
b94f70f
first draft
jayzhan211 Feb 16, 2024
c743d13
clippy fix
jayzhan211 Feb 18, 2024
3a7e965
cleanup
jayzhan211 Feb 18, 2024
4917f56
use one vector for ordering req
jayzhan211 Feb 21, 2024
c9e8641
add sort exprs to accumulator
jayzhan211 Feb 21, 2024
3a5f0d1
clippy
jayzhan211 Feb 21, 2024
a3ea00a
cleanup
jayzhan211 Feb 21, 2024
f349f21
fix doc test
jayzhan211 Feb 21, 2024
6fcdaac
change to ref
jayzhan211 Feb 27, 2024
c3512a6
fix typo
jayzhan211 Feb 27, 2024
092d46e
fix doc
jayzhan211 Feb 27, 2024
8592e6b
fmt
jayzhan211 Mar 1, 2024
0f8fc24
move schema and logical ordering exprs
jayzhan211 Mar 1, 2024
3185f9f
remove redudant info
jayzhan211 Mar 1, 2024
3ecc772
rename
jayzhan211 Mar 1, 2024
faadc63
cleanup
jayzhan211 Mar 1, 2024
7e33910
add ignore nulls
jayzhan211 Mar 7, 2024
cfffcbf
Merge remote-tracking branch 'upstream/main' into udf-order-2
jayzhan211 Mar 25, 2024
6aaa15c
fix conflict
jayzhan211 Mar 25, 2024
b74b7d2
backup
jayzhan211 Mar 26, 2024
263e6cb
complete return_type
jayzhan211 Mar 26, 2024
0a77e4f
complete replace
jayzhan211 Mar 30, 2024
7b26377
split to first value udf
jayzhan211 Mar 30, 2024
4bfd91d
replace accumulator
jayzhan211 Mar 30, 2024
7f54141
fmt
jayzhan211 Mar 30, 2024
6339535
cleanup
jayzhan211 Mar 30, 2024
33ae6ee
small fix
jayzhan211 Mar 30, 2024
b4eb865
remove ordering types
jayzhan211 Mar 30, 2024
d8ab6c5
make state fields more flexible
jayzhan211 Mar 30, 2024
a3bff42
cleanup
jayzhan211 Mar 30, 2024
53465fd
replace done
jayzhan211 Mar 30, 2024
cc21496
cleanup
jayzhan211 Mar 30, 2024
b62544f
cleanup
jayzhan211 Mar 30, 2024
ddfabad
Merge remote-tracking branch 'upstream/main' into first-value-udf
jayzhan211 Mar 30, 2024
4b809b0
rm comments
jayzhan211 Mar 30, 2024
2534727
cleanup
jayzhan211 Mar 30, 2024
17378dd
rm test1
jayzhan211 Mar 30, 2024
dd1c4ba
fix state fields
jayzhan211 Mar 31, 2024
5d5d310
fmt
jayzhan211 Mar 31, 2024
23f20f9
args struct for accumulator
jayzhan211 Mar 31, 2024
b2ba8c3
simplify
jayzhan211 Mar 31, 2024
75aa2fe
add sig
jayzhan211 Mar 31, 2024
5b9625f
add comments
jayzhan211 Mar 31, 2024
d5c3f6f
fmt
jayzhan211 Mar 31, 2024
dc9549a
fix docs
jayzhan211 Apr 1, 2024
7ce3d41
Merge remote-tracking branch 'upstream/main' into first-value-udf
jayzhan211 Apr 1, 2024
49b4a76
use exprs utils
jayzhan211 Apr 1, 2024
d70cce5
rm state type
jayzhan211 Apr 2, 2024
29c4018
add comment
jayzhan211 Apr 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_physical_expr::NullState;
use std::{any::Any, sync::Arc};
Expand All @@ -30,7 +31,8 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
GroupsAccumulator, Signature,
};

/// This example shows how to use the full AggregateUDFImpl API to implement a user
Expand Down Expand Up @@ -85,13 +87,21 @@ impl AggregateUDFImpl for GeoMeanUdaf {
/// is supported, DataFusion will use this row oriented
/// accumulator when the aggregate function is used as a window function
/// or when there are only aggregates (no GROUP BY columns) in the plan.
fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about the impact on this API for UDAF writers last night.

Specifically, about the many existing UDAFs that exist / will exist at the time this change gets released and on the first time people encounter / try to use this API. i think the args with datatypes is much easier to use (and has less mental gymnastics to use). Thus I am going to propose an easier / beginner API for this that will require fewer changes to existing UDAFs and will be easier to use for first timers

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what i came up with: #9920

Ok(Box::new(GeometricMean::new()))
}

/// This is the description of the state. accumulator's state() must match the types here.
fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
Ok(vec![DataType::Float64, DataType::UInt32])
fn state_fields(
&self,
_name: &str,
value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
Ok(vec![
Field::new("prod", value_type, true),
Field::new("n", DataType::UInt32, true),
])
}

/// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator`
Expand Down Expand Up @@ -191,7 +201,6 @@ impl Accumulator for GeometricMean {

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::arrow::datatypes::{Field, Schema};
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Expand Down
20 changes: 20 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@ use datafusion_common::{
OwnedTableReference, SchemaReference,
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::{create_first_value, Signature, Volatility};
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
var_provider::is_system_variables,
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
use datafusion_physical_expr::create_first_value_accumulator;
use datafusion_sql::{
parser::{CopyToSource, CopyToStatement, DFParser},
planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel},
Expand All @@ -82,6 +85,7 @@ use datafusion_sql::{

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use log::debug;
use parking_lot::RwLock;
use sqlparser::dialect::dialect_from_str;
use url::Url;
Expand Down Expand Up @@ -1457,6 +1461,22 @@ impl SessionState {
datafusion_functions_array::register_all(&mut new_self)
.expect("can not register array expressions");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also panic if register fails here


let first_value = create_first_value(
"FIRST_VALUE",
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
Arc::new(create_first_value_accumulator),
);

match new_self.register_udaf(Arc::new(first_value)) {
Ok(Some(existing_udaf)) => {
debug!("Overwrite existing UDAF: {}", existing_udaf.name());
}
Ok(None) => {}
Err(err) => {
panic!("Failed to register UDAF: {}", err);
}
}

new_self
}
/// Returns new [`SessionState`] using the provided
Expand Down
50 changes: 31 additions & 19 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,24 +247,20 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
distinct,
args,
filter,
order_by,
order_by: _,
null_treatment: _,
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(..) => {
create_function_physical_name(func_def.name(), *distinct, args)
}
AggregateFunctionDefinition::UDF(fun) => {
// TODO: Add support for filter and order by in AggregateUDF
// TODO: Add support for filter by in AggregateUDF
if filter.is_some() {
return exec_err!(
"aggregate expression with filter is not supported"
);
}
if order_by.is_some() {
return exec_err!(
"aggregate expression with order_by is not supported"
);
}

let names = args
.iter()
.map(|e| create_physical_name(e, false))
Expand Down Expand Up @@ -1667,20 +1663,22 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
)?),
None => None,
};
let order_by = match order_by {
Some(e) => Some(create_physical_sort_exprs(
e,
logical_input_schema,
execution_props,
)?),
None => None,
};

let ignore_nulls = null_treatment
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;
let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let ordering_reqs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
logical_input_schema,
execution_props,
)?),
None => None,
};
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = aggregates::create_aggregate_expr(
fun,
*distinct,
Expand All @@ -1690,16 +1688,30 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
name,
ignore_nulls,
)?;
(agg_expr, filter, order_by)
(agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::UDF(fun) => {
let sort_exprs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
logical_input_schema,
execution_props,
)?),
None => None,
};
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = udaf::create_aggregate_expr(
fun,
&args,
&sort_exprs,
&ordering_reqs,
physical_input_schema,
name,
);
(agg_expr?, filter, order_by)
ignore_nulls,
)?;
(agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::Name(_) => {
return internal_err!(
Expand Down
20 changes: 11 additions & 9 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;

Expand Down Expand Up @@ -491,7 +492,7 @@ impl TimeSum {
// Returns the same type as its input
let return_type = timestamp_type.clone();

let state_type = vec![timestamp_type.clone()];
let state_fields = vec![Field::new("sum", timestamp_type, true)];

let volatility = Volatility::Immutable;

Expand All @@ -505,7 +506,7 @@ impl TimeSum {
return_type,
volatility,
accumulator,
state_type,
state_fields,
));

// register the selector as "time_sum"
Expand Down Expand Up @@ -591,6 +592,11 @@ impl FirstSelector {
fn register(ctx: &mut SessionContext) {
let return_type = Self::output_datatype();
let state_type = Self::state_datatypes();
let state_fields = state_type
.into_iter()
.enumerate()
.map(|(i, t)| Field::new(format!("{i}"), t, true))
.collect::<Vec<_>>();

// Possible input signatures
let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];
Expand All @@ -607,7 +613,7 @@ impl FirstSelector {
Signature::one_of(signatures, volatility),
return_type,
accumulator,
state_type,
state_fields,
));

// register the selector as "first"
Expand Down Expand Up @@ -717,15 +723,11 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
Ok(DataType::UInt64)
}

fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
// should use groups accumulator
panic!("accumulator shouldn't invoke");
}

fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
Ok(vec![DataType::UInt64])
}

fn groups_accumulator_supported(&self) -> bool {
true
}
Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,14 +577,15 @@ impl AggregateFunction {
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
null_treatment: Option<NullTreatment>,
) -> Self {
Self {
func_def: AggregateFunctionDefinition::UDF(udf),
args,
distinct,
filter,
order_by,
null_treatment: None,
null_treatment,
}
}
}
Expand Down
Loading
Loading