-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
UDAF: Add more fields to state fields #10391
Changes from 1 commit
4c243c2
d252ff1
49e9a2a
994ecb8
b840a50
5eb43ea
f8310ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,8 +30,10 @@ use datafusion_common::{ | |
ScalarValue, | ||
}; | ||
use datafusion_expr::{ | ||
function::AccumulatorArgs, type_coercion::aggregates::NUMERICS, | ||
utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility, | ||
function::{AccumulatorArgs, StateFieldsArgs}, | ||
type_coercion::aggregates::NUMERICS, | ||
utils::format_state_name, | ||
Accumulator, AggregateUDFImpl, Signature, Volatility, | ||
}; | ||
use datafusion_physical_expr_common::aggregate::stats::StatsType; | ||
|
||
|
@@ -93,18 +95,25 @@ impl AggregateUDFImpl for CovarianceSample { | |
Ok(DataType::Float64) | ||
} | ||
|
||
fn state_fields( | ||
&self, | ||
name: &str, | ||
_value_type: DataType, | ||
_ordering_fields: Vec<Field>, | ||
) -> Result<Vec<Field>> { | ||
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { | ||
Ok(vec![ | ||
Field::new(format_state_name(name, "count"), DataType::UInt64, true), | ||
Field::new(format_state_name(name, "mean1"), DataType::Float64, true), | ||
Field::new(format_state_name(name, "mean2"), DataType::Float64, true), | ||
Field::new( | ||
format_state_name(name, "algo_const"), | ||
format_state_name(args.name, "count"), | ||
DataType::UInt64, | ||
args.nullable, | ||
), | ||
Field::new( | ||
format_state_name(args.name, "mean1"), | ||
DataType::Float64, | ||
args.nullable, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Aren't some of these fields nullable even if their input is not nullable (e.g. the mean with no inputs is null)? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, this should be |
||
), | ||
Field::new( | ||
format_state_name(args.name, "mean2"), | ||
DataType::Float64, | ||
args.nullable, | ||
), | ||
Field::new( | ||
format_state_name(args.name, "algo_const"), | ||
DataType::Float64, | ||
true, | ||
), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ pub mod utils; | |
|
||
use arrow::datatypes::{DataType, Field, Schema}; | ||
use datafusion_common::{not_impl_err, Result}; | ||
use datafusion_expr::function::StateFieldsArgs; | ||
use datafusion_expr::type_coercion::aggregates::check_arg_count; | ||
use datafusion_expr::{ | ||
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, | ||
|
@@ -54,23 +55,27 @@ pub fn create_aggregate_expr( | |
&fun.signature().type_signature, | ||
)?; | ||
|
||
let ordering_types = ordering_req | ||
let order_by_data_types = ordering_req | ||
.iter() | ||
.map(|e| e.expr.data_type(schema)) | ||
.collect::<Result<Vec<_>>>()?; | ||
|
||
let ordering_fields = ordering_fields(ordering_req, &ordering_types); | ||
let ordering_fields = ordering_fields(ordering_req, &order_by_data_types); | ||
let nullable = input_phy_exprs[0].nullable(schema)?; | ||
|
||
Ok(Arc::new(AggregateFunctionExpr { | ||
fun: fun.clone(), | ||
args: input_phy_exprs.to_vec(), | ||
data_type: fun.return_type(&input_exprs_types)?, | ||
input_type: input_exprs_types[0].clone(), | ||
return_type: fun.return_type(&input_exprs_types)?, | ||
name: name.into(), | ||
schema: schema.clone(), | ||
sort_exprs: sort_exprs.to_vec(), | ||
order_by_data_types, | ||
ordering_req: ordering_req.to_vec(), | ||
ignore_nulls, | ||
ordering_fields, | ||
nullable, | ||
})) | ||
} | ||
|
||
|
@@ -152,16 +157,21 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> { | |
pub struct AggregateFunctionExpr { | ||
fun: AggregateUDF, | ||
args: Vec<Arc<dyn PhysicalExpr>>, | ||
/// input type | ||
input_type: DataType, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I found most of the builtin function name is |
||
/// Output / return type of this aggregate | ||
data_type: DataType, | ||
return_type: DataType, | ||
name: String, | ||
schema: Schema, | ||
// The logical order by expressions | ||
sort_exprs: Vec<Expr>, | ||
// The physical order by expressions | ||
ordering_req: LexOrdering, | ||
// The data types of the order by expressions | ||
order_by_data_types: Vec<DataType>, | ||
ignore_nulls: bool, | ||
ordering_fields: Vec<Field>, | ||
nullable: bool, | ||
} | ||
|
||
impl AggregateFunctionExpr { | ||
|
@@ -182,20 +192,26 @@ impl AggregateExpr for AggregateFunctionExpr { | |
} | ||
|
||
fn state_fields(&self) -> Result<Vec<Field>> { | ||
self.fun.state_fields( | ||
self.name(), | ||
self.data_type.clone(), | ||
self.ordering_fields.clone(), | ||
) | ||
let args = StateFieldsArgs::new( | ||
&self.name, | ||
&self.input_type, | ||
&self.return_type, | ||
&self.ordering_fields, | ||
&self.order_by_data_types, | ||
self.nullable, | ||
); | ||
|
||
self.fun.state_fields(args) | ||
} | ||
|
||
// TODO: Add field function in AggregateUDFImpl | ||
fn field(&self) -> Result<Field> { | ||
Ok(Field::new(&self.name, self.data_type.clone(), true)) | ||
Ok(Field::new(&self.name, self.input_type.clone(), true)) | ||
} | ||
|
||
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { | ||
let acc_args = AccumulatorArgs::new( | ||
&self.data_type, | ||
&self.input_type, | ||
&self.schema, | ||
self.ignore_nulls, | ||
&self.sort_exprs, | ||
|
@@ -282,7 +298,7 @@ impl PartialEq<dyn Any> for AggregateFunctionExpr { | |
.downcast_ref::<Self>() | ||
.map(|x| { | ||
self.name == x.name | ||
&& self.data_type == x.data_type | ||
&& self.input_type == x.input_type | ||
&& self.fun == x.fun | ||
&& self.args.len() == x.args.len() | ||
&& self | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found that most of the `state_fields` implementations have their own version, so return an error if it is not defined now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be an API change -- I think if we want to change it we should also update the documentation to reflect the change