Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions datafusion/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ use std::{
use super::dfschema::ToDFSchema;
use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
use crate::logical_plan::{
columnize_expr, normalize_col, normalize_cols, Column, CrossJoin, DFField, DFSchema,
DFSchemaRef, Limit, Partitioning, Repartition, Values,
columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, Column,
CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition, Values,
};
use crate::sql::utils::group_window_expr_by_sort_keys;

Expand Down Expand Up @@ -521,6 +521,8 @@ impl LogicalPlanBuilder {
&self,
exprs: impl IntoIterator<Item = impl Into<Expr>> + Clone,
) -> Result<Self> {
let exprs = rewrite_sort_cols_by_aggs(exprs, &self.plan)?;

let schema = self.plan.schema();

// Collect sort columns that are missing in the input plan's schema
Expand All @@ -530,7 +532,7 @@ impl LogicalPlanBuilder {
.into_iter()
.try_for_each::<_, Result<()>>(|expr| {
let mut columns: HashSet<Column> = HashSet::new();
utils::expr_to_columns(&expr.into(), &mut columns)?;
utils::expr_to_columns(&expr, &mut columns)?;

columns.into_iter().for_each(|c| {
if schema.field_from_column(&c).is_err() {
Expand Down
78 changes: 76 additions & 2 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
pub use super::Operator;
use crate::error::{DataFusionError, Result};
use crate::field_util::get_indexed_field;
use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan};
use crate::logical_plan::{
plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan,
};
use crate::physical_plan::functions::Volatility;
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
Expand Down Expand Up @@ -1306,7 +1308,6 @@ fn normalize_col_with_schemas(
}

/// Recursively normalize all Column expressions in a list of expression trees
#[inline]
pub fn normalize_cols(
exprs: impl IntoIterator<Item = impl Into<Expr>>,
plan: &LogicalPlan,
Expand All @@ -1317,6 +1318,79 @@ pub fn normalize_cols(
.collect()
}

/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
Comment thread
viirya marked this conversation as resolved.
pub fn rewrite_sort_cols_by_aggs(
exprs: impl IntoIterator<Item = impl Into<Expr>>,
plan: &LogicalPlan,
) -> Result<Vec<Expr>> {
exprs
.into_iter()
.map(|e| {
let expr = e.into();
match expr.clone() {
Expr::Sort {
expr,
asc,
nulls_first,
} => {
let sort = Expr::Sort {
expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?),
asc,
nulls_first,
};
Ok(sort)
}
_ => Ok(expr),
Comment thread
viirya marked this conversation as resolved.
Outdated
}
})
.collect()
}

/// Rewrites a single sort key so that references to aggregate expressions
/// become references to the aggregate's output column.
///
/// The behaviour depends on the shape of `plan`:
///
/// * `LogicalPlan::Aggregate`: `expr` is run through an [`ExprRewriter`] that
///   replaces each expression matching one of the plan's `aggr_expr` entries
///   with an `Expr::Column`.
/// * `LogicalPlan::Projection`: the rewrite is retried against the
///   projection's first input.
/// * anything else: `expr` is returned unchanged.
fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
    match plan {
        LogicalPlan::Aggregate(Aggregate {
            input, aggr_expr, ..
        }) => {
            /// Rewriter that swaps aggregate expressions for their output column.
            struct Rewriter<'a> {
                /// The aggregate plan that expressions are normalized against.
                plan: &'a LogicalPlan,
                /// The aggregate's input; its schema is used to build the column.
                input: &'a LogicalPlan,
                /// The aggregate expressions that a sort key may refer to.
                aggr_expr: &'a [Expr],
            }

            impl<'a> ExprRewriter for Rewriter<'a> {
                fn mutate(&mut self, expr: Expr) -> Result<Expr> {
                    let normalized_expr = match normalize_col(expr.clone(), self.plan) {
                        Ok(normalized) => normalized,
                        // The expr is not based on Aggregate plan output. Skip it.
                        Err(_) => return Ok(expr),
                    };

                    match self.aggr_expr.iter().find(|a| **a == normalized_expr) {
                        Some(found_agg) => {
                            let agg = normalize_col(found_agg.clone(), self.plan)?;
                            // Sort on the column named after the aggregate
                            // rather than on the aggregate expression itself.
                            let field = agg.to_field(self.input.schema())?;
                            Ok(Expr::Column(field.qualified_column()))
                        }
                        None => Ok(expr),
                    }
                }
            }

            expr.rewrite(&mut Rewriter {
                plan,
                input,
                aggr_expr,
            })
        }
        // Retry the rewrite against the projection's first input.
        LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
        _ => Ok(expr),
    }
}

/// Recursively 'unnormalize' (remove all qualifiers) from an
/// expression tree.
///
Expand Down
10 changes: 5 additions & 5 deletions datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ pub use expr::{
create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list,
initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim,
max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random,
regexp_match, regexp_replace, repeat, replace, replace_col, reverse, right, round,
rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias,
unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter,
ExpressionVisitor, Literal, Recursion, RewriteRecursion,
regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512,
signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when,
Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
Expand Down
21 changes: 21 additions & 0 deletions datafusion/tests/sql/order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ async fn test_sort_unprojected_col() -> Result<()> {
Ok(())
}

/// `ORDER BY` may name an aggregate expression, either bare or inside a larger
/// expression; both queries below must yield the same single-row result.
#[tokio::test]
async fn test_order_by_agg_expr() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    register_aggregate_csv(&mut ctx).await?;

    let expected = vec![
        "+-----------------------------+",
        "| MIN(aggregate_test_100.c12) |",
        "+-----------------------------+",
        "| 0.01479305307777301 |",
        "+-----------------------------+",
    ];

    let queries = [
        "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12)",
        "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12) + 0.1",
    ];
    for sql in queries.iter() {
        let actual = execute_to_batches(&mut ctx, sql).await;
        assert_batches_eq!(expected, &actual);
    }

    Ok(())
}

#[tokio::test]
async fn test_nulls_first_asc() -> Result<()> {
let mut ctx = ExecutionContext::new();
Expand Down