From dd2243ae616a23a49167eb4ba4db3b6982bd49a9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 29 Dec 2021 10:55:26 -0800 Subject: [PATCH 1/7] Fix sort on aggregate --- datafusion/src/logical_plan/builder.rs | 6 ++- datafusion/src/logical_plan/expr.rs | 58 +++++++++++++++++++++++++- datafusion/src/logical_plan/mod.rs | 10 ++--- datafusion/tests/sql.rs | 17 ++++++++ 4 files changed, 83 insertions(+), 8 deletions(-) diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 90d2ae22241e8..549e584d4927e 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -46,8 +46,8 @@ use std::{ use super::dfschema::ToDFSchema; use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType}; use crate::logical_plan::{ - columnize_expr, normalize_col, normalize_cols, Column, CrossJoin, DFField, DFSchema, - DFSchemaRef, Limit, Partitioning, Repartition, Values, + columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, Column, + CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition, Values, }; use crate::sql::utils::group_window_expr_by_sort_keys; @@ -521,6 +521,8 @@ impl LogicalPlanBuilder { &self, exprs: impl IntoIterator> + Clone, ) -> Result { + let exprs = rewrite_sort_cols_by_aggs(exprs, &self.plan)?; + let schema = self.plan.schema(); // Collect sort columns that are missing in the input plan's schema diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index fc862cd9ae376..d58ec18e138e9 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -21,7 +21,9 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; use crate::field_util::get_indexed_field; -use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan}; +use crate::logical_plan::{ + plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan, +}; use crate::physical_plan::functions::Volatility; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, @@ -1317,6 +1319,60 @@ pub fn normalize_cols( .collect() } +/// Rewrite sort on aggregate expressions to sort on the column of aggregate output +#[inline] +pub fn rewrite_sort_cols_by_aggs( + exprs: impl IntoIterator>, + plan: &LogicalPlan, +) -> Result> { + exprs + .into_iter() + .map(|e| { + let expr = e.into(); + match expr.clone() { + Expr::Sort { + expr, + asc, + nulls_first, + } => { + let sort = Expr::Sort { + expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), + asc, + nulls_first, + }; + Ok(sort) + } + _ => Ok(expr), + } + }) + .collect() +} + +fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr: _, + aggr_expr, + schema: _, + }) => { + let normalized_expr = normalize_col(expr.clone(), plan)?; + let found_agg = aggr_expr.into_iter().find(|a| (**a) == normalized_expr); + if found_agg.is_some() { + let agg = normalize_col(found_agg.unwrap().clone(), plan)?; + let col = Expr::Column( + agg.to_field(input.schema()).map(|f| f.qualified_column())?, + ); + Ok(col) + } else { + Ok(expr) + } + } + LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), + _ => Ok(expr), + } +} + /// Recursively 'unnormalize' (remove all qualifiers) from an /// expression tree. /// diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index a20d572067497..56fec3cf1a0c4 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -42,11 +42,11 @@ pub use expr::{ create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random, - regexp_match, regexp_replace, repeat, replace, replace_col, reverse, right, round, - rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, - starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias, - unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter, - ExpressionVisitor, Literal, Recursion, RewriteRecursion, + regexp_match, regexp_replace, repeat, replace, replace_col, reverse, + rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, + signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, + translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, + Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 7c3210dd7599e..ed630053c36a2 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -391,6 +391,23 @@ async fn csv_query_with_is_null_predicate() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_order_by_agg_expr() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------------------------+", + "| MIN(aggregate_test_100.c12) |", + "+-----------------------------+", + "| 0.01479305307777301 |", + "+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_group_by_int_min_max() -> Result<()> { let mut ctx = ExecutionContext::new(); From 23b9d3acd8d7519ab8ea2fdee40b9149be9622e2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 29 Dec 2021 11:24:25 -0800 Subject: [PATCH 2/7] Use ExprRewriter. --- datafusion/src/logical_plan/builder.rs | 2 +- datafusion/src/logical_plan/expr.rs | 43 ++++++++++++++++++++------ datafusion/tests/sql/order.rs | 4 +++ 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 549e584d4927e..fc609390bcc0d 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -532,7 +532,7 @@ impl LogicalPlanBuilder { .into_iter() .try_for_each::<_, Result<()>>(|expr| { let mut columns: HashSet = HashSet::new(); - utils::expr_to_columns(&expr.into(), &mut columns)?; + utils::expr_to_columns(&expr, &mut columns)?; columns.into_iter().for_each(|c| { if schema.field_from_column(&c).is_err() { diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index d58ec18e138e9..00e70231fb66c 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1356,17 +1356,40 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { aggr_expr, schema: _, }) => { - let normalized_expr = normalize_col(expr.clone(), plan)?; - let found_agg = aggr_expr.into_iter().find(|a| (**a) == normalized_expr); - if found_agg.is_some() { - let agg = normalize_col(found_agg.unwrap().clone(), plan)?; - let col = Expr::Column( - agg.to_field(input.schema()).map(|f| f.qualified_column())?, - ); - Ok(col) - } else { - Ok(expr) + struct Rewriter<'a> { + plan: &'a LogicalPlan, + input: &'a LogicalPlan, + aggr_expr: &'a Vec, + } + + impl<'a> ExprRewriter for Rewriter<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + let normalized_expr = normalize_col(expr.clone(), self.plan); + if normalized_expr.is_err() { + // The expr is not based on Aggregate plan output. Skip it. + return Ok(expr); + } + let normalized_expr = normalized_expr.unwrap(); + let found_agg = + self.aggr_expr.iter().find(|a| (**a) == normalized_expr); + if found_agg.is_some() { + let agg = normalize_col(found_agg.unwrap().clone(), self.plan)?; + let col = Expr::Column( + agg.to_field(self.input.schema()) + .map(|f| f.qualified_column())?, + ); + Ok(col) + } else { + Ok(expr) + } + } } + + expr.rewrite(&mut Rewriter { + plan, + input, + aggr_expr, + }) } LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), _ => Ok(expr), diff --git a/datafusion/tests/sql/order.rs b/datafusion/tests/sql/order.rs index 65c3959781dc8..fa59d9d196615 100644 --- a/datafusion/tests/sql/order.rs +++ b/datafusion/tests/sql/order.rs @@ -46,6 +46,10 @@ async fn test_order_by_agg_expr() -> Result<()> { "+-----------------------------+", ]; assert_batches_eq!(expected, &actual); + + let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12) + 0.1"; + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); Ok(()) } From f2ce277223450880a7d127ab3c4e7deaf5571e12 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Dec 2021 22:30:15 -0800 Subject: [PATCH 3/7] For review comment --- datafusion/src/logical_plan/expr.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 00e70231fb66c..c211f3bf4b953 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1308,7 +1308,6 @@ fn normalize_col_with_schemas( } /// Recursively normalize all Column expressions in a list of expression trees -#[inline] pub fn normalize_cols( exprs: impl IntoIterator>, plan: &LogicalPlan, @@ -1320,7 +1319,6 @@ pub fn normalize_cols( } /// Rewrite sort on aggregate expressions to sort on the column of aggregate output -#[inline] pub fn rewrite_sort_cols_by_aggs( exprs: impl IntoIterator>, plan: &LogicalPlan, @@ -1351,10 +1349,7 @@ pub fn rewrite_sort_cols_by_aggs( fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Aggregate(Aggregate { - input, - group_expr: _, - aggr_expr, - schema: _, + input, aggr_expr, .. }) => { struct Rewriter<'a> { plan: &'a LogicalPlan, From 6374b50ff976ecc69a0dfcbf934bed919247d6e3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 31 Dec 2021 10:40:08 -0800 Subject: [PATCH 4/7] Update datafusion/src/logical_plan/expr.rs Co-authored-by: Andrew Lamb --- datafusion/src/logical_plan/expr.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index c211f3bf4b953..af492b67cd83f 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1327,7 +1327,7 @@ pub fn rewrite_sort_cols_by_aggs( .into_iter() .map(|e| { let expr = e.into(); - match expr.clone() { + match expr { Expr::Sort { expr, asc, @@ -1340,7 +1340,7 @@ pub fn rewrite_sort_cols_by_aggs( }; Ok(sort) } - _ => Ok(expr), + expr => Ok(expr), } }) .collect() From 20b590d9e4742a49a0872582cc57718448c9d743 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 31 Dec 2021 10:40:16 -0800 Subject: [PATCH 5/7] Update datafusion/src/logical_plan/expr.rs Co-authored-by: Andrew Lamb --- datafusion/src/logical_plan/expr.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index af492b67cd83f..4eb2344e3559b 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1319,6 +1319,7 @@ pub fn normalize_cols( } /// Rewrite sort on aggregate expressions to sort on the column of aggregate output +/// For example, `max(x)` is written to `col("MAX(x)")` pub fn rewrite_sort_cols_by_aggs( exprs: impl IntoIterator>, plan: &LogicalPlan, From 52499971ef96e73c009e5f8a11b20092fff4f9ed Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 31 Dec 2021 10:40:22 -0800 Subject: [PATCH 6/7] Update datafusion/src/logical_plan/expr.rs Co-authored-by: Andrew Lamb --- datafusion/src/logical_plan/expr.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 4eb2344e3559b..3ef522b8c2eba 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1366,10 +1366,8 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { return Ok(expr); } let normalized_expr = normalized_expr.unwrap(); - let found_agg = - self.aggr_expr.iter().find(|a| (**a) == normalized_expr); - if found_agg.is_some() { - let agg = normalize_col(found_agg.unwrap().clone(), self.plan)?; + if let Some(found_agg) = self.aggr_expr.iter().find(|a| (**a) == normalized_expr) { + let agg = normalize_col(found_agg, self.plan)?; let col = Expr::Column( agg.to_field(self.input.schema()) .map(|f| f.qualified_column())?, From 927a2cfce2725b20371999980b5c56368c1fb7a2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 31 Dec 2021 10:53:09 -0800 Subject: [PATCH 7/7] Fix format. --- datafusion/src/logical_plan/expr.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 3ef522b8c2eba..dadc168530745 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1319,7 +1319,7 @@ pub fn normalize_cols( } /// Rewrite sort on aggregate expressions to sort on the column of aggregate output -/// For example, `max(x)` is written to `col("MAX(x)")` +/// For example, `max(x)` is written to `col("MAX(x)")` pub fn rewrite_sort_cols_by_aggs( exprs: impl IntoIterator>, plan: &LogicalPlan, @@ -1366,8 +1366,10 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { return Ok(expr); } let normalized_expr = normalized_expr.unwrap(); - if let Some(found_agg) = self.aggr_expr.iter().find(|a| (**a) == normalized_expr) { - let agg = normalize_col(found_agg, self.plan)?; + if let Some(found_agg) = + self.aggr_expr.iter().find(|a| (**a) == normalized_expr) + { + let agg = normalize_col(found_agg.clone(), self.plan)?; let col = Expr::Column( agg.to_field(self.input.schema()) .map(|f| f.qualified_column())?,