Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UDAF: Add more fields to state fields #10391

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
GroupsAccumulator, Signature,
function::{AccumulatorArgs, FieldArgs, StateFieldsArgs},
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the full AggregateUDFImpl API to implement a user
Expand Down Expand Up @@ -81,6 +81,10 @@ impl AggregateUDFImpl for GeoMeanUdaf {
Ok(DataType::Float64)
}

/// Describes the output field of this aggregate: named `args.name`, typed
/// `args.return_type`, and always marked nullable.
// NOTE(review): `args.nullable` is not consulted; confirm that hard-coding `true` is intended.
fn field(&self, args: FieldArgs) -> Result<Field> {
Ok(Field::new(args.name, args.return_type.clone(), true))
}

/// This is the accumulator factory; DataFusion uses it to create new accumulators.
///
/// This is the accumulator factory for row wise accumulation; Even when `GroupsAccumulator`
Expand All @@ -92,14 +96,9 @@ impl AggregateUDFImpl for GeoMeanUdaf {
}

/// This is the description of the state. accumulator's state() must match the types here.
fn state_fields(
&self,
_name: &str,
value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
Ok(vec![
Field::new("prod", value_type, true),
Field::new("prod", args.return_type.clone(), true),
Field::new("n", DataType::UInt32, true),
])
}
Expand Down
18 changes: 16 additions & 2 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
create_udaf,
function::{AccumulatorArgs, FieldArgs, StateFieldsArgs},
utils::format_state_name,
AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;

Expand Down Expand Up @@ -716,6 +718,18 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
&self.signature
}

/// Builds the output field from `args.name` and `args.return_type`; the
/// field is always nullable (`args.nullable` is not read).
fn field(&self, args: FieldArgs) -> Result<Field> {
Ok(Field::new(args.name, args.return_type.clone(), true))
}

/// Declares a single intermediate-state field. Its name comes from
/// `format_state_name(args.name, "aliased_aggregate_state")`, its type is
/// `args.return_type`, and it is marked nullable.
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![Field::new(
format_state_name(args.name, "aliased_aggregate_state"),
args.return_type.clone(),
true,
)])
}

/// The result type is always `UInt64`; the argument types are ignored.
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::UInt64)
}
Expand Down
14 changes: 7 additions & 7 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ use crate::expr::{
Placeholder, TryCast,
};
use crate::function::{
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
AccumulatorArgs, AccumulatorFactoryFunction, FieldArgs, PartitionEvaluatorFactory,
StateFieldsArgs,
};
use crate::{
aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery,
Expand Down Expand Up @@ -690,12 +691,11 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
(self.accumulator)(acc_args)
}

fn state_fields(
&self,
_name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
/// Builds the output field from `args.name` and `args.return_type`,
/// marking it nullable regardless of `args.nullable`.
fn field(&self, args: FieldArgs) -> Result<Field> {
Ok(Field::new(args.name, args.return_type.clone(), true))
}

/// Returns a clone of the state fields stored on this UDF; the supplied
/// `_args` are not used.
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(self.state_fields.clone())
}
}
Expand Down
68 changes: 67 additions & 1 deletion datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use crate::ColumnarValue;
use crate::{Accumulator, Expr, PartitionEvaluator};
use arrow::datatypes::{DataType, Schema};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use std::sync::Arc;

Expand Down Expand Up @@ -84,6 +84,72 @@ impl<'a> AccumulatorArgs<'a> {
}
}

/// `FieldArgs` encapsulates the details needed to build the field of an aggregate function.
///
/// - `name`: Name of the aggregate function.
/// - `input_type`: Input type of the aggregate function.
/// - `return_type`: Return type of the aggregate function. Defined by `fn return_type` in AggregateUDFImpl.
/// - `nullable`: Indicates whether the field can be null.
pub struct FieldArgs<'a> {
pub name: &'a str,
pub input_type: &'a DataType,
pub return_type: &'a DataType,
pub nullable: bool,
}

impl<'a> FieldArgs<'a> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given all the fields are pub, I wonder how much benefit this API adds over simply creating FieldArgs directly 🤔

Copy link
Contributor Author

@jayzhan211 jayzhan211 May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But `name`, for example, includes the arguments and is not the same every time, so how would they get the correct field without `args.name`?

/// Creates a `FieldArgs` from its four components, storing each argument
/// unchanged in the field of the same name.
pub fn new(
name: &'a str,
input_type: &'a DataType,
return_type: &'a DataType,
nullable: bool,
) -> Self {
Self {
name,
input_type,
return_type,
nullable,
}
}
}

/// `StateFieldsArgs` encapsulates details regarding the required state fields for an aggregate function.
///
/// - `name`: Name of the aggregate function.
/// - `input_type`: Input type of the aggregate function.
/// - `return_type`: Return type of the aggregate function. Defined by `fn return_type` in AggregateUDFImpl.
/// - `ordering_fields`: Fields utilized for functions sensitive to ordering.
/// - `order_by_data_types`: Data types for the ordering fields.
/// - `nullable`: Indicates whether the state fields can be null.
pub struct StateFieldsArgs<'a> {
pub name: &'a str,
pub input_type: &'a DataType,
pub return_type: &'a DataType,
pub ordering_fields: &'a [Field],
pub order_by_data_types: &'a [DataType],
pub nullable: bool,
}

impl<'a> StateFieldsArgs<'a> {
/// Creates a `StateFieldsArgs` from its six components, storing each
/// argument unchanged in the field of the same name.
pub fn new(
name: &'a str,
input_type: &'a DataType,
return_type: &'a DataType,
ordering_fields: &'a [Field],
order_by_data_types: &'a [DataType],
nullable: bool,
) -> Self {
Self {
name,
input_type,
return_type,
ordering_fields,
order_by_data_types,
nullable,
}
}
}

/// Factory that returns an accumulator for the given aggregate function.
pub type AccumulatorFactoryFunction =
Arc<dyn Fn(AccumulatorArgs) -> Result<Box<dyn Accumulator>> + Send + Sync>;
Expand Down
68 changes: 39 additions & 29 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::function::AccumulatorArgs;
use crate::function::{AccumulatorArgs, FieldArgs, StateFieldsArgs};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
Expand All @@ -27,7 +27,6 @@ use datafusion_common::{not_impl_err, Result};
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::vec;

/// Logical representation of a user-defined [aggregate function] (UDAF).
///
Expand Down Expand Up @@ -172,18 +171,18 @@ impl AggregateUDF {
self.inner.accumulator(acc_args)
}

/// See [`AggregateUDFImpl::field`] for more details.
pub fn field(&self, args: FieldArgs) -> Result<Field> {
self.inner.field(args)
}

/// Return the fields used to store the intermediate state for this aggregator, given
/// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
/// for more details.
///
/// This is used to support multi-phase aggregations
pub fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
self.inner.state_fields(name, value_type, ordering_fields)
pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
self.inner.state_fields(args)
}

/// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
Expand Down Expand Up @@ -222,7 +221,7 @@ where
/// # use arrow::datatypes::DataType;
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr};
/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs};
/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs, function::StateFieldsArgs};
/// # use arrow::datatypes::Schema;
/// # use arrow::datatypes::Field;
/// #[derive(Debug, Clone)]
Expand Down Expand Up @@ -251,9 +250,9 @@ where
/// }
/// // This is the accumulator factory; DataFusion uses it to create new accumulators.
/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec<Field>) -> Result<Vec<Field>> {
/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
/// Ok(vec![
/// Field::new("value", value_type, true),
/// Field::new("value", args.return_type.clone(), true),
/// Field::new("ordering", DataType::UInt32, true)
/// ])
/// }
Expand Down Expand Up @@ -287,18 +286,26 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// aggregate function was called.
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;

/// Return the output field for the function
fn field(&self, _args: FieldArgs) -> Result<Field> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the same information as return_type? If so, perhaps we should deprecate the return_type function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`return_type` is only part of the information used by `field` and `state_fields`, so it cannot be replaced by `field`.

not_impl_err!("field hasn't been implemented for {self:?} yet")
}

/// Return the fields used to store the intermediate state of this accumulator.
///
/// # Arguments:
/// 1. `name`: the name of the expression (e.g. AVG, SUM, etc)
/// 2. `value_type`: Aggregate's aggregate's output (returned by [`Self::return_type`])
/// 3. `ordering_fields`: the fields used to order the input arguments, if any.
/// - `name`: the name of the expression (e.g. AVG(args...), SUM(args...), etc)
/// - `input_type`: the input type of the aggregate function
/// - `return_type`: the return type of the aggregate function (returned by [`Self::return_type`])
/// - `ordering_fields`: the fields used to order the input arguments, if any.
/// Empty if no ordering expression is provided.
/// - `order_by_data_types`: the data types of the ordering fields.
/// - `nullable`: whether the field can be null.
///
/// # Notes:
///
/// The default implementation returns a single state field named `name`
/// with the same type as `value_type`. This is suitable for aggregates such
/// with the same type as `input_type`. This is suitable for aggregates such
/// as `SUM` or `MIN` where partial state can be combined by applying the
/// same aggregate.
///
Expand All @@ -309,19 +316,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// The name of the fields must be unique within the query and thus should
/// be derived from `name`. See [`format_state_name`] for a utility function
/// to generate a unique name.
fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
let value_fields = vec![Field::new(
format_state_name(name, "value"),
value_type,
true,
)];

Ok(value_fields.into_iter().chain(ordering_fields).collect())
///
/// [`format_state_name`]: crate::utils::format_state_name
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
not_impl_err!("state_fields hasn't been implemented for {self:?} yet")
Copy link
Contributor Author

@jayzhan211 jayzhan211 May 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found that most aggregates define their own `state_fields`, so the default now returns an error when it is not defined.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be an API change -- I think if we want to change it we should also update the documentation to reflect the change

}

/// If the aggregate expression has a specialized
Expand Down Expand Up @@ -385,6 +383,18 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
self.inner.name()
}

/// Builds the output field from `args.name` and `args.return_type`; the
/// field is always nullable (`args.nullable` is not read).
fn field(&self, args: FieldArgs) -> Result<Field> {
Ok(Field::new(args.name, args.return_type.clone(), true))
}

/// Declares a single nullable state field named via
/// `format_state_name(args.name, "aliased_aggregate_state")` with type
/// `args.return_type`.
// NOTE(review): unlike `name` and `signature`, this does not delegate to
// `self.inner`; confirm a single state field is right for every aliased aggregate.
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![Field::new(
format_state_name(args.name, "aliased_aggregate_state"),
args.return_type.clone(),
true,
)])
}

fn signature(&self) -> &Signature {
self.inner.signature()
}
Expand Down
37 changes: 25 additions & 12 deletions datafusion/functions-aggregate/src/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ use datafusion_common::{
ScalarValue,
};
use datafusion_expr::{
function::AccumulatorArgs, type_coercion::aggregates::NUMERICS,
utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility,
function::{AccumulatorArgs, StateFieldsArgs},
type_coercion::aggregates::NUMERICS,
utils::format_state_name,
Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_physical_expr_common::aggregate::stats::StatsType;

Expand Down Expand Up @@ -93,18 +95,29 @@ impl AggregateUDFImpl for CovarianceSample {
Ok(DataType::Float64)
}

fn state_fields(
&self,
name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
/// Builds the output field from `args.name` and `args.return_type`,
/// marking it nullable regardless of `args.nullable`.
fn field(&self, args: datafusion_expr::function::FieldArgs) -> Result<Field> {
Ok(Field::new(args.name, args.return_type.clone(), true))
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
Field::new(
format_state_name(name, "algo_const"),
format_state_name(args.name, "count"),
DataType::UInt64,
true,
),
Field::new(
format_state_name(args.name, "mean1"),
DataType::Float64,
true,
),
Field::new(
format_state_name(args.name, "mean2"),
DataType::Float64,
true,
),
Field::new(
format_state_name(args.name, "algo_const"),
DataType::Float64,
true,
),
Expand Down
Loading