-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
UDAF: Add more fields to state fields #10391
Changes from all commits
4c243c2
d252ff1
49e9a2a
994ecb8
b840a50
5eb43ea
f8310ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ | |
|
||
//! [`AggregateUDF`]: User Defined Aggregate Functions | ||
|
||
use crate::function::AccumulatorArgs; | ||
use crate::function::{AccumulatorArgs, FieldArgs, StateFieldsArgs}; | ||
use crate::groups_accumulator::GroupsAccumulator; | ||
use crate::utils::format_state_name; | ||
use crate::{Accumulator, Expr}; | ||
|
@@ -27,7 +27,6 @@ use datafusion_common::{not_impl_err, Result}; | |
use std::any::Any; | ||
use std::fmt::{self, Debug, Formatter}; | ||
use std::sync::Arc; | ||
use std::vec; | ||
|
||
/// Logical representation of a user-defined [aggregate function] (UDAF). | ||
/// | ||
|
@@ -172,18 +171,18 @@ impl AggregateUDF { | |
self.inner.accumulator(acc_args) | ||
} | ||
|
||
/// See [`AggregateUDFImpl::field`] for more details. | ||
pub fn field(&self, args: FieldArgs) -> Result<Field> { | ||
self.inner.field(args) | ||
} | ||
|
||
/// Return the fields used to store the intermediate state for this aggregator, given | ||
/// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`] | ||
/// for more details. | ||
/// | ||
/// This is used to support multi-phase aggregations | ||
pub fn state_fields( | ||
&self, | ||
name: &str, | ||
value_type: DataType, | ||
ordering_fields: Vec<Field>, | ||
) -> Result<Vec<Field>> { | ||
self.inner.state_fields(name, value_type, ordering_fields) | ||
pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { | ||
self.inner.state_fields(args) | ||
} | ||
|
||
/// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. | ||
|
@@ -222,7 +221,7 @@ where | |
/// # use arrow::datatypes::DataType; | ||
/// # use datafusion_common::{DataFusionError, plan_err, Result}; | ||
/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; | ||
/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs}; | ||
/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs, function::StateFieldsArgs}; | ||
/// # use arrow::datatypes::Schema; | ||
/// # use arrow::datatypes::Field; | ||
/// #[derive(Debug, Clone)] | ||
|
@@ -251,9 +250,9 @@ where | |
/// } | ||
/// // This is the accumulator factory; DataFusion uses it to create new accumulators. | ||
/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() } | ||
/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec<Field>) -> Result<Vec<Field>> { | ||
/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { | ||
/// Ok(vec![ | ||
/// Field::new("value", value_type, true), | ||
/// Field::new("value", args.return_type.clone(), true), | ||
/// Field::new("ordering", DataType::UInt32, true) | ||
/// ]) | ||
/// } | ||
|
@@ -287,18 +286,26 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { | |
/// aggregate function was called. | ||
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>; | ||
|
||
/// Return the fields for the function | ||
fn field(&self, _args: FieldArgs) -> Result<Field> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the same information as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
not_impl_err!("field hasn't been implemented for {self:?} yet") | ||
} | ||
|
||
/// Return the fields used to store the intermediate state of this accumulator. | ||
/// | ||
/// # Arguments: | ||
/// 1. `name`: the name of the expression (e.g. AVG, SUM, etc) | ||
/// 2. `value_type`: Aggregate's aggregate's output (returned by [`Self::return_type`]) | ||
/// 3. `ordering_fields`: the fields used to order the input arguments, if any. | ||
/// - `name`: the name of the expression (e.g. AVG(args...), SUM(args...), etc) | ||
/// - `input_type`: the input type of the aggregate function | ||
/// - `return_type`: the return type of the aggregate function (returned by [`Self::return_type`]) | ||
/// - `ordering_fields`: the fields used to order the input arguments, if any. | ||
/// Empty if no ordering expression is provided. | ||
/// - `order_by_data_types`: the data types of the ordering fields. | ||
/// - `nullable`: whether the field can be null. | ||
/// | ||
/// # Notes: | ||
/// | ||
/// The default implementation returns a single state field named `name` | ||
/// with the same type as `value_type`. This is suitable for aggregates such | ||
/// with the same type as `input_type`. This is suitable for aggregates such | ||
/// as `SUM` or `MIN` where partial state can be combined by applying the | ||
/// same aggregate. | ||
/// | ||
|
@@ -309,19 +316,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { | |
/// The name of the fields must be unique within the query and thus should | ||
/// be derived from `name`. See [`format_state_name`] for a utility function | ||
/// to generate a unique name. | ||
fn state_fields( | ||
&self, | ||
name: &str, | ||
value_type: DataType, | ||
ordering_fields: Vec<Field>, | ||
) -> Result<Vec<Field>> { | ||
let value_fields = vec![Field::new( | ||
format_state_name(name, "value"), | ||
value_type, | ||
true, | ||
)]; | ||
|
||
Ok(value_fields.into_iter().chain(ordering_fields).collect()) | ||
/// | ||
/// [`format_state_name`]: crate::utils::format_state_name | ||
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> { | ||
not_impl_err!("state_fields hasn't been implemented for {self:?} yet") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found most of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would be an API change -- I think if we want to change it we should also update the documentation to reflect the change |
||
} | ||
|
||
/// If the aggregate expression has a specialized | ||
|
@@ -385,6 +383,18 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { | |
self.inner.name() | ||
} | ||
|
||
fn field(&self, args: FieldArgs) -> Result<Field> { | ||
Ok(Field::new(args.name, args.return_type.clone(), true)) | ||
} | ||
|
||
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { | ||
Ok(vec![Field::new( | ||
format_state_name(args.name, "aliased_aggregate_state"), | ||
args.return_type.clone(), | ||
true, | ||
)]) | ||
} | ||
|
||
fn signature(&self) -> &Signature { | ||
self.inner.signature() | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given all the fields are pub, I wonder how much benefit this API adds over simply creating FieldArgs directly 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But like
name
, it includes arguments, and not every time is the same, how do they get the correct field withoutargs.name
?