fix: impl ordering for serialization/deserialization for AggregateUdf (#11926)

* fix: support ordering and percentile function ser/de

* add more test cases
haohuaijin authored Aug 12, 2024
1 parent 5251dc9 commit 032b9c9
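In short: when a physical plan containing an aggregate UDAF went through the protobuf path, the ORDER BY was decoded and then discarded, so ordering-sensitive aggregates (and `approx_percentile_cont`, per the removed TODO comments below) did not survive serialization/deserialization. This commit forwards the decoded sort expressions to `AggregateExprBuilder::order_by` during deserialization and drops the now-unused logical-argument plumbing from `create_window_expr`. A minimal sketch of the kind of expression the fix targets, modelled on the new roundtrip tests in this commit (the schema, alias, function name, and import paths are illustrative):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::SortOptions;
use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder;
use datafusion_common::Result;
use datafusion_functions_aggregate::array_agg::array_agg_udaf;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::PhysicalSortExpr;

fn ordered_array_agg() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
    ]));

    // ARRAY_AGG(b ORDER BY b ASC NULLS FIRST): before this commit the
    // `.order_by(..)` part was lost when the plan was rebuilt from protobuf.
    let _agg = AggregateExprBuilder::new(array_agg_udaf(), vec![col("b", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("ARRAY_AGG(b)")
        .order_by(vec![PhysicalSortExpr {
            expr: col("b", &schema)?,
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }])
        .build()?;
    Ok(())
}
```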
Showing 8 changed files with 71 additions and 22 deletions.
1 change: 0 additions & 1 deletion datafusion/core/src/physical_optimizer/test_utils.rs
@@ -251,7 +251,6 @@ pub fn bounded_window_exec(
"count".to_owned(),
&[col(col_name, &schema).unwrap()],
&[],
&[],
&sort_exprs,
Arc::new(WindowFrame::new(Some(false))),
schema.as_ref(),
1 change: 0 additions & 1 deletion datafusion/core/src/physical_planner.rs
@@ -1510,7 +1510,6 @@ pub fn create_window_expr_with_name(
fun,
name,
&physical_args,
args,
&partition_by,
&order_by,
window_frame,
4 changes: 0 additions & 4 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -253,7 +253,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> {

let partitionby_exprs = vec![];
let orderby_exprs = vec![];
let logical_exprs = vec![];
// Window frame starts with "UNBOUNDED PRECEDING":
let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None));

@@ -285,7 +284,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
&window_fn,
fn_name.to_string(),
&args,
&logical_exprs,
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame),
@@ -674,7 +672,6 @@ async fn run_window_test(
&window_fn,
fn_name.clone(),
&args,
&[],
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
@@ -693,7 +690,6 @@ async fn run_window_test(
&window_fn,
fn_name,
&args,
&[],
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
@@ -1196,7 +1196,7 @@ mod tests {
RecordBatchStream, SendableRecordBatchStream, TaskContext,
};
use datafusion_expr::{
Expr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::expressions::{col, Column, NthValue};
@@ -1303,10 +1303,7 @@ mod tests {
let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf());
let col_expr =
Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc<dyn PhysicalExpr>;
let log_expr =
Expr::Column(datafusion_common::Column::from(schema.fields[0].name()));
let args = vec![col_expr];
let log_args = vec![log_expr];
let partitionby_exprs = vec![col(hash, &schema)?];
let orderby_exprs = vec![PhysicalSortExpr {
expr: col(order_by, &schema)?,
@@ -1327,7 +1324,6 @@ mod tests {
&window_fn,
fn_name,
&args,
&log_args,
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
6 changes: 2 additions & 4 deletions datafusion/physical-plan/src/windows/mod.rs
@@ -32,8 +32,8 @@ use arrow::datatypes::Schema;
use arrow_schema::{DataType, Field, SchemaRef};
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::{
BuiltInWindowFunction, Expr, PartitionEvaluator, WindowFrame,
WindowFunctionDefinition, WindowUDF,
BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition,
WindowUDF,
};
use datafusion_physical_expr::equivalence::collapse_lex_req;
use datafusion_physical_expr::{
@@ -94,7 +94,6 @@ pub fn create_window_expr(
fun: &WindowFunctionDefinition,
name: String,
args: &[Arc<dyn PhysicalExpr>],
_logical_args: &[Expr],
partition_by: &[Arc<dyn PhysicalExpr>],
order_by: &[PhysicalSortExpr],
window_frame: Arc<WindowFrame>,
@@ -746,7 +745,6 @@ mod tests {
&[col("a", &schema)?],
&[],
&[],
&[],
Arc::new(WindowFrame::new(None)),
schema.as_ref(),
false,
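With the unused `_logical_args: &[Expr]` parameter gone, callers of `create_window_expr` pass only physical expressions. A minimal sketch of the post-change call shape, pieced together from the hunks above; the import path and the trailing `schema.as_ref()` / `ignore_nulls` arguments mirror the updated test call and should be treated as assumptions about the current signature:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::{WindowFrame, WindowFunctionDefinition};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_plan::windows::create_window_expr;

fn count_window_expr() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf());

    // No `&[Expr]` logical-args slice between the physical args and PARTITION BY.
    let _expr = create_window_expr(
        &window_fn,
        "count".to_owned(),
        &[col("a", &schema)?],            // physical argument expressions only
        &[],                              // PARTITION BY
        &[],                              // ORDER BY
        Arc::new(WindowFrame::new(None)),
        schema.as_ref(),
        false,                            // ignore_nulls
    )?;
    Ok(())
}
```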
3 changes: 0 additions & 3 deletions datafusion/proto/src/physical_plan/from_proto.rs
@@ -169,13 +169,10 @@ pub fn parse_physical_window_expr(
// TODO: Remove extended_schema if functions are all UDAF
let extended_schema =
schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?;
// approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
let logical_exprs = &[];
create_window_expr(
&fun,
name,
&window_node_expr,
logical_exprs,
&partition_by,
&order_by,
Arc::new(window_frame),
6 changes: 2 additions & 4 deletions datafusion/proto/src/physical_plan/mod.rs
@@ -477,7 +477,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
ExprType::AggregateExpr(agg_node) => {
let input_phy_expr: Vec<Arc<dyn PhysicalExpr>> = agg_node.expr.iter()
.map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
let _ordering_req: Vec<PhysicalSortExpr> = agg_node.ordering_req.iter()
let ordering_req: Vec<PhysicalSortExpr> = agg_node.ordering_req.iter()
.map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
agg_node.aggregate_function.as_ref().map(|func| {
match func {
@@ -487,14 +487,12 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
None => registry.udaf(udaf_name)?
};

// TODO: approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
// TODO: `order by` is not supported for UDAF yet
// https://github.com/apache/datafusion/issues/11804
AggregateExprBuilder::new(agg_udf, input_phy_expr)
.schema(Arc::clone(&physical_schema))
.alias(name)
.with_ignore_nulls(agg_node.ignore_nulls)
.with_distinct(agg_node.distinct)
.order_by(ordering_req)
.build()
}
}
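The heart of the fix is the hunk above: the sort expressions decoded from `agg_node.ordering_req` are no longer bound to an unused `_ordering_req` and dropped, but handed to the builder. Condensed, the deserialization path now reads as follows (a fragment of the surrounding match arm, not a standalone function):

```rust
// Decode the aggregate's ORDER BY from protobuf ...
let ordering_req: Vec<PhysicalSortExpr> = agg_node.ordering_req.iter()
    .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec))
    .collect::<Result<Vec<_>>>()?;

// ... and apply it when rebuilding the physical aggregate expression, so
// ordering-sensitive UDAFs keep their ORDER BY after deserialization.
AggregateExprBuilder::new(agg_udf, input_phy_expr)
    .schema(Arc::clone(&physical_schema))
    .alias(name)
    .with_ignore_nulls(agg_node.ignore_nulls)
    .with_distinct(agg_node.distinct)
    .order_by(ordering_req)
    .build()
```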
66 changes: 66 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -25,6 +25,8 @@ use std::vec;
use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder;
use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf;
use datafusion_functions_aggregate::array_agg::array_agg_udaf;
use datafusion_functions_aggregate::min_max::max_udaf;
use prost::Message;

@@ -412,6 +414,70 @@ fn rountrip_aggregate_with_limit() -> Result<()> {
roundtrip_test(Arc::new(agg))
}

#[test]
fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> {
let field_a = Field::new("a", DataType::Int64, false);
let field_b = Field::new("b", DataType::Int64, false);
let schema = Arc::new(Schema::new(vec![field_a, field_b]));

let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("a", &schema)?, "unused".to_string())];

let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![AggregateExprBuilder::new(
approx_percentile_cont_udaf(),
vec![col("b", &schema)?, lit(0.5)],
)
.schema(Arc::clone(&schema))
.alias("APPROX_PERCENTILE_CONT(b, 0.5)")
.build()?];

let agg = AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::new_single(groups.clone()),
aggregates.clone(),
vec![None],
Arc::new(EmptyExec::new(schema.clone())),
schema,
)?;
roundtrip_test(Arc::new(agg))
}

#[test]
fn rountrip_aggregate_with_sort() -> Result<()> {
let field_a = Field::new("a", DataType::Int64, false);
let field_b = Field::new("b", DataType::Int64, false);
let schema = Arc::new(Schema::new(vec![field_a, field_b]));

let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("a", &schema)?, "unused".to_string())];
let sort_exprs = vec![PhysicalSortExpr {
expr: col("b", &schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];

let aggregates: Vec<Arc<dyn AggregateExpr>> =
vec![
AggregateExprBuilder::new(array_agg_udaf(), vec![col("b", &schema)?])
.schema(Arc::clone(&schema))
.alias("ARRAY_AGG(b)")
.order_by(sort_exprs)
.build()?,
];

let agg = AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::new_single(groups.clone()),
aggregates.clone(),
vec![None],
Arc::new(EmptyExec::new(schema.clone())),
schema,
)?;
roundtrip_test(Arc::new(agg))
}

#[test]
fn roundtrip_aggregate_udaf() -> Result<()> {
let field_a = Field::new("a", DataType::Int64, false);
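The two new tests build an `AggregateExec` with `approx_percentile_cont(b, 0.5)` and with `ARRAY_AGG(b)` plus an ORDER BY, and hand it to `roundtrip_test`. Roughly, that helper encodes the plan to its protobuf node, decodes it back, and compares the two plans. A sketch of such a check is below, assuming `datafusion_proto`'s `AsExecutionPlan` and `DefaultPhysicalExtensionCodec` APIs; the real helper in this file may differ in details:

```rust
use std::sync::Arc;

use datafusion::execution::context::SessionContext;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::Result;
use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
use datafusion_proto::protobuf;

/// Encode a physical plan to protobuf, decode it back, and check that nothing
/// was lost. Before this commit, the ORDER BY on ARRAY_AGG(b) would have been
/// dropped from `decoded`, so the comparison would fail.
fn assert_roundtrip(plan: Arc<dyn ExecutionPlan>, ctx: &SessionContext) -> Result<()> {
    let codec = DefaultPhysicalExtensionCodec {};
    let node = protobuf::PhysicalPlanNode::try_from_physical_plan(Arc::clone(&plan), &codec)?;
    let decoded = node.try_into_physical_plan(ctx, &ctx.runtime_env(), &codec)?;
    assert_eq!(format!("{plan:?}"), format!("{decoded:?}"));
    Ok(())
}
```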
