Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UDAF: Add more fields to state fields #10391

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
GroupsAccumulator, Signature,
function::{AccumulatorArgs, StateFieldsArgs},
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the full AggregateUDFImpl API to implement a user
Expand Down Expand Up @@ -92,14 +92,9 @@ impl AggregateUDFImpl for GeoMeanUdaf {
}

/// This is the description of the state. accumulator's state() must match the types here.
fn state_fields(
&self,
_name: &str,
value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
Ok(vec![
Field::new("prod", value_type, true),
Field::new("prod", args.input_type.clone(), true),
Field::new("n", DataType::UInt32, true),
])
}
Expand Down
8 changes: 2 additions & 6 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::expr::{
};
use crate::function::{
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
StateFieldsArgs,
};
use crate::{
aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery,
Expand Down Expand Up @@ -690,12 +691,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
(self.accumulator)(acc_args)
}

fn state_fields(
&self,
_name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(self.state_fields.clone())
}
}
Expand Down
39 changes: 38 additions & 1 deletion datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use crate::ColumnarValue;
use crate::{Accumulator, Expr, PartitionEvaluator};
use arrow::datatypes::{DataType, Schema};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use std::sync::Arc;

Expand Down Expand Up @@ -84,6 +84,43 @@ impl<'a> AccumulatorArgs<'a> {
}
}

/// `StateFieldsArgs` encapsulates details regarding the required state fields for an aggregate function.
///
/// - `name`: Name of the aggregate function.
/// - `input_type`: Input type of the aggregate function.
/// - `return_type`: Return type of the aggregate function. Defined by `fn return_type` in AggregateUDFImpl.
/// - `ordering_fields`: Fields utilized for functions sensitive to ordering.
/// - `order_by_data_types`: Data types for the ordering fields.
/// - `nullable`: Indicates whether the state fields can be null.
pub struct StateFieldsArgs<'a> {
pub name: &'a str,
pub input_type: &'a DataType,
pub return_type: &'a DataType,
pub ordering_fields: &'a [Field],
pub order_by_data_types: &'a [DataType],
pub nullable: bool,
}

impl<'a> StateFieldsArgs<'a> {
pub fn new(
name: &'a str,
input_type: &'a DataType,
return_type: &'a DataType,
ordering_fields: &'a [Field],
order_by_data_types: &'a [DataType],
nullable: bool,
) -> Self {
Self {
name,
input_type,
return_type,
ordering_fields,
order_by_data_types,
nullable,
}
}
}

/// Factory that returns an accumulator for the given aggregate function.
pub type AccumulatorFactoryFunction =
Arc<dyn Fn(AccumulatorArgs) -> Result<Box<dyn Accumulator>> + Send + Sync>;
Expand Down
34 changes: 9 additions & 25 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,15 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::function::AccumulatorArgs;
use crate::function::{AccumulatorArgs, StateFieldsArgs};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{not_impl_err, Result};
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::vec;

/// Logical representation of a user-defined [aggregate function] (UDAF).
///
Expand Down Expand Up @@ -177,13 +175,8 @@ impl AggregateUDF {
/// for more details.
///
/// This is used to support multi-phase aggregations
pub fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
self.inner.state_fields(name, value_type, ordering_fields)
pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
self.inner.state_fields(args)
}

/// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
Expand Down Expand Up @@ -291,14 +284,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
///
/// # Arguments:
/// 1. `name`: the name of the expression (e.g. AVG, SUM, etc)
/// 2. `value_type`: Aggregate's aggregate's output (returned by [`Self::return_type`])
/// 2. `input_type`: Aggregate's aggregate's output (returned by [`Self::return_type`])
/// 3. `ordering_fields`: the fields used to order the input arguments, if any.
/// Empty if no ordering expression is provided.
///
/// # Notes:
///
/// The default implementation returns a single state field named `name`
/// with the same type as `value_type`. This is suitable for aggregates such
/// with the same type as `input_type`. This is suitable for aggregates such
/// as `SUM` or `MIN` where partial state can be combined by applying the
/// same aggregate.
///
Expand All @@ -309,19 +302,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// The name of the fields must be unique within the query and thus should
/// be derived from `name`. See [`format_state_name`] for a utility function
/// to generate a unique name.
fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
let value_fields = vec![Field::new(
format_state_name(name, "value"),
value_type,
true,
)];

Ok(value_fields.into_iter().chain(ordering_fields).collect())
///
/// [`format_state_name`]: crate::utils::format_state_name
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
not_impl_err!("state_fields hasn't been implemented for {self:?} yet")
Copy link
Contributor Author

@jayzhan211 jayzhan211 May 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found most of the state_fields have their own version, so return err if not defined now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be an API change -- I think if we want to change it we should also update the documentation to reflect the change

}

/// If the aggregate expression has a specialized
Expand Down
33 changes: 21 additions & 12 deletions datafusion/functions-aggregate/src/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ use datafusion_common::{
ScalarValue,
};
use datafusion_expr::{
function::AccumulatorArgs, type_coercion::aggregates::NUMERICS,
utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility,
function::{AccumulatorArgs, StateFieldsArgs},
type_coercion::aggregates::NUMERICS,
utils::format_state_name,
Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_physical_expr_common::aggregate::stats::StatsType;

Expand Down Expand Up @@ -93,18 +95,25 @@ impl AggregateUDFImpl for CovarianceSample {
Ok(DataType::Float64)
}

fn state_fields(
&self,
name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
Field::new(
format_state_name(name, "algo_const"),
format_state_name(args.name, "count"),
DataType::UInt64,
args.nullable,
),
Field::new(
format_state_name(args.name, "mean1"),
DataType::Float64,
args.nullable,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aren't some of these fields nullable even if their input is not nullable (e.g. mean with no inputs is null)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this should be true

),
Field::new(
format_state_name(args.name, "mean2"),
DataType::Float64,
args.nullable,
),
Field::new(
format_state_name(args.name, "algo_const"),
DataType::Float64,
true,
),
Expand Down
15 changes: 5 additions & 10 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at
use datafusion_common::{
arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Expand Down Expand Up @@ -147,18 +147,13 @@ impl AggregateUDFImpl for FirstValue {
.map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
}

fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let mut fields = vec![Field::new(
format_state_name(name, "first_value"),
value_type,
format_state_name(args.name, "first_value"),
args.input_type.clone(),
true,
)];
fields.extend(ordering_fields);
fields.extend(args.ordering_fields.to_vec());
fields.push(Field::new("is_set", DataType::Boolean, true));
Ok(fields)
}
Expand Down
40 changes: 28 additions & 12 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub mod utils;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
Expand Down Expand Up @@ -54,23 +55,27 @@ pub fn create_aggregate_expr(
&fun.signature().type_signature,
)?;

let ordering_types = ordering_req
let order_by_data_types = ordering_req
.iter()
.map(|e| e.expr.data_type(schema))
.collect::<Result<Vec<_>>>()?;

let ordering_fields = ordering_fields(ordering_req, &ordering_types);
let ordering_fields = ordering_fields(ordering_req, &order_by_data_types);
let nullable = input_phy_exprs[0].nullable(schema)?;

Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
args: input_phy_exprs.to_vec(),
data_type: fun.return_type(&input_exprs_types)?,
input_type: input_exprs_types[0].clone(),
return_type: fun.return_type(&input_exprs_types)?,
name: name.into(),
schema: schema.clone(),
sort_exprs: sort_exprs.to_vec(),
order_by_data_types,
ordering_req: ordering_req.to_vec(),
ignore_nulls,
ordering_fields,
nullable,
}))
}

Expand Down Expand Up @@ -152,16 +157,21 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> {
pub struct AggregateFunctionExpr {
fun: AggregateUDF,
args: Vec<Arc<dyn PhysicalExpr>>,
/// input type
input_type: DataType,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found most of the builtin function name is input_data_type so I rename it.

/// Output / return type of this aggregate
data_type: DataType,
return_type: DataType,
name: String,
schema: Schema,
// The logical order by expressions
sort_exprs: Vec<Expr>,
// The physical order by expressions
ordering_req: LexOrdering,
// The data types of the order by expressions
order_by_data_types: Vec<DataType>,
ignore_nulls: bool,
ordering_fields: Vec<Field>,
nullable: bool,
}

impl AggregateFunctionExpr {
Expand All @@ -182,20 +192,26 @@ impl AggregateExpr for AggregateFunctionExpr {
}

fn state_fields(&self) -> Result<Vec<Field>> {
self.fun.state_fields(
self.name(),
self.data_type.clone(),
self.ordering_fields.clone(),
)
let args = StateFieldsArgs::new(
&self.name,
&self.input_type,
&self.return_type,
&self.ordering_fields,
&self.order_by_data_types,
self.nullable,
);

self.fun.state_fields(args)
}

// TODO: Add field function in AggregateUDFImpl
fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, self.data_type.clone(), true))
Ok(Field::new(&self.name, self.input_type.clone(), true))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let acc_args = AccumulatorArgs::new(
&self.data_type,
&self.input_type,
&self.schema,
self.ignore_nulls,
&self.sort_exprs,
Expand Down Expand Up @@ -282,7 +298,7 @@ impl PartialEq<dyn Any> for AggregateFunctionExpr {
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.data_type == x.data_type
&& self.input_type == x.input_type
&& self.fun == x.fun
&& self.args.len() == x.args.len()
&& self
Expand Down
Loading