Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions datafusion/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ use std::{
use super::dfschema::ToDFSchema;
use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
use crate::logical_plan::{
columnize_expr, normalize_col, normalize_cols, Column, CrossJoin, DFField, DFSchema,
DFSchemaRef, Limit, Partitioning, Repartition, Values,
columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, Column,
CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition, Values,
};
use crate::sql::utils::group_window_expr_by_sort_keys;

Expand Down Expand Up @@ -521,6 +521,8 @@ impl LogicalPlanBuilder {
&self,
exprs: impl IntoIterator<Item = impl Into<Expr>> + Clone,
) -> Result<Self> {
let exprs = rewrite_sort_cols_by_aggs(exprs, &self.plan)?;

let schema = self.plan.schema();

// Collect sort columns that are missing in the input plan's schema
Expand All @@ -530,7 +532,7 @@ impl LogicalPlanBuilder {
.into_iter()
.try_for_each::<_, Result<()>>(|expr| {
let mut columns: HashSet<Column> = HashSet::new();
utils::expr_to_columns(&expr.into(), &mut columns)?;
utils::expr_to_columns(&expr, &mut columns)?;

columns.into_iter().for_each(|c| {
if schema.field_from_column(&c).is_err() {
Expand Down
79 changes: 77 additions & 2 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
pub use super::Operator;
use crate::error::{DataFusionError, Result};
use crate::field_util::get_indexed_field;
use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan};
use crate::logical_plan::{
plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan,
};
use crate::physical_plan::functions::Volatility;
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
Expand Down Expand Up @@ -1306,7 +1308,6 @@ fn normalize_col_with_schemas(
}

/// Recursively normalize all Column expressions in a list of expression trees
#[inline]
pub fn normalize_cols(
exprs: impl IntoIterator<Item = impl Into<Expr>>,
plan: &LogicalPlan,
Expand All @@ -1317,6 +1318,80 @@ pub fn normalize_cols(
.collect()
}

/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
/// For example, `max(x)` is written to `col("MAX(x)")`
pub fn rewrite_sort_cols_by_aggs(
exprs: impl IntoIterator<Item = impl Into<Expr>>,
plan: &LogicalPlan,
) -> Result<Vec<Expr>> {
exprs
.into_iter()
.map(|e| {
let expr = e.into();
match expr {
Expr::Sort {
expr,
asc,
nulls_first,
} => {
let sort = Expr::Sort {
expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?),
asc,
nulls_first,
};
Ok(sort)
}
expr => Ok(expr),
}
})
.collect()
}

fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
match plan {
LogicalPlan::Aggregate(Aggregate {
input, aggr_expr, ..
}) => {
struct Rewriter<'a> {
plan: &'a LogicalPlan,
input: &'a LogicalPlan,
aggr_expr: &'a Vec<Expr>,
}

impl<'a> ExprRewriter for Rewriter<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
let normalized_expr = normalize_col(expr.clone(), self.plan);
if normalized_expr.is_err() {
// The expr is not based on Aggregate plan output. Skip it.
return Ok(expr);
}
let normalized_expr = normalized_expr.unwrap();
if let Some(found_agg) =
self.aggr_expr.iter().find(|a| (**a) == normalized_expr)
{
let agg = normalize_col(found_agg.clone(), self.plan)?;
let col = Expr::Column(
agg.to_field(self.input.schema())
.map(|f| f.qualified_column())?,
);
Ok(col)
} else {
Ok(expr)
}
}
}

expr.rewrite(&mut Rewriter {
plan,
input,
aggr_expr,
})
}
LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
_ => Ok(expr),
}
}

/// Recursively 'unnormalize' (remove all qualifiers) from an
/// expression tree.
///
Expand Down
10 changes: 5 additions & 5 deletions datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ pub use expr::{
create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list,
initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim,
max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random,
regexp_match, regexp_replace, repeat, replace, replace_col, reverse, right, round,
rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias,
unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter,
ExpressionVisitor, Literal, Recursion, RewriteRecursion,
regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512,
signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when,
Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
Expand Down
21 changes: 21 additions & 0 deletions datafusion/tests/sql/order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ async fn test_sort_unprojected_col() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_order_by_agg_expr() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx).await?;
let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12)";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+-----------------------------+",
"| MIN(aggregate_test_100.c12) |",
"+-----------------------------+",
"| 0.01479305307777301 |",
"+-----------------------------+",
];
assert_batches_eq!(expected, &actual);

let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12) + 0.1";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let actual = execute_to_batches(&mut ctx, sql).await;
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn test_nulls_first_asc() -> Result<()> {
let mut ctx = ExecutionContext::new();
Expand Down