From 85adb6c4e6c0b6009f9866118c318b078263e118 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Thu, 29 Aug 2024 17:38:28 +0200 Subject: [PATCH] Remove Sort expression (`Expr::Sort`) (#12177) * Take Sort (SortExpr) in file options Part of effort to remove `Expr::Sort`. * Return Sort from Expr.Sort Part of effort to remove `Expr::Sort`. * Accept Sort (SortExpr) in `LogicalPlanBuilder.sort` Take `expr::Sort` in `LogicalPlanBuilder.sort`. Accept any `Expr` in a new function, `LogicalPlanBuilder.sort_by`, which applies the default sort ordering. Part of effort to remove `Expr::Sort`. * Operate on `Sort` in to_substrait_sort_field / from_substrait_sorts Part of effort to remove `Expr::Sort`. * Take Sort (SortExpr) in tests' TopKPlanNode Part of effort to remove `Expr::Sort`. * Remove Sort expression (`Expr::Sort`) Remove sort as an expression, i.e. remove `Expr::Sort` from the `Expr` enum. Use `expr::Sort` directly when sorting. The sort expression was used in the context of ordering (sort, topk, create table, file sorting). Those places require their sort expression to be of type Sort anyway and no other expression was allowed, so this change improves static typing. Sort as an expression was illegal in other contexts. 
* use assert_eq just like in LogicalPlan.with_new_exprs * avoid clone in replace_sort_expressions * reduce cloning in EliminateDuplicatedExpr * restore SortExprWrapper this commit is longer than advised in the review comment, but after squashing the diff will be smaller * shorthand SortExprWrapper struct definition --- .../examples/file_stream_provider.rs | 4 +- datafusion/core/src/dataframe/mod.rs | 15 +- .../src/datasource/file_format/options.rs | 14 +- .../core/src/datasource/listing/helpers.rs | 3 +- .../core/src/datasource/listing/table.rs | 14 +- datafusion/core/src/datasource/memory.rs | 5 +- datafusion/core/src/datasource/mod.rs | 41 ++-- .../physical_plan/file_scan_config.rs | 2 +- datafusion/core/src/datasource/stream.rs | 6 +- datafusion/core/src/physical_planner.rs | 32 ++- datafusion/core/src/test_util/mod.rs | 4 +- datafusion/core/tests/dataframe/mod.rs | 20 +- datafusion/core/tests/expr_api/mod.rs | 17 +- datafusion/core/tests/fifo/mod.rs | 4 +- .../core/tests/fuzz_cases/limit_fuzz.rs | 2 +- .../tests/user_defined/user_defined_plan.rs | 9 +- datafusion/expr/src/expr.rs | 113 +++++------ datafusion/expr/src/expr_fn.rs | 28 +-- datafusion/expr/src/expr_rewriter/mod.rs | 34 ++-- datafusion/expr/src/expr_rewriter/order_by.rs | 38 ++-- datafusion/expr/src/expr_schema.rs | 15 +- datafusion/expr/src/logical_plan/builder.rs | 38 ++-- datafusion/expr/src/logical_plan/ddl.rs | 7 +- datafusion/expr/src/logical_plan/plan.rs | 41 ++-- datafusion/expr/src/logical_plan/tree_node.rs | 18 +- datafusion/expr/src/tree_node.rs | 53 ++++- datafusion/expr/src/utils.rs | 192 +++++++----------- datafusion/expr/src/window_frame.rs | 4 +- .../functions-aggregate/src/first_last.rs | 4 +- .../src/analyzer/count_wildcard_rule.rs | 2 +- .../optimizer/src/analyzer/type_coercion.rs | 7 +- .../optimizer/src/common_subexpr_eliminate.rs | 15 +- .../src/eliminate_duplicated_expr.rs | 30 +-- datafusion/optimizer/src/eliminate_limit.rs | 12 +- datafusion/optimizer/src/push_down_filter.rs 
| 3 +- datafusion/optimizer/src/push_down_limit.rs | 8 +- .../simplify_expressions/expr_simplifier.rs | 1 - .../src/single_distinct_to_groupby.rs | 4 +- datafusion/proto/proto/datafusion.proto | 17 +- datafusion/proto/src/generated/pbjson.rs | 105 ++++++++-- datafusion/proto/src/generated/prost.rs | 26 ++- .../proto/src/logical_plan/from_proto.rs | 49 +++-- datafusion/proto/src/logical_plan/mod.rs | 40 ++-- datafusion/proto/src/logical_plan/to_proto.rs | 48 +++-- .../tests/cases/roundtrip_logical_plan.rs | 10 +- datafusion/sql/src/expr/function.rs | 25 +-- datafusion/sql/src/expr/order_by.rs | 8 +- datafusion/sql/src/query.rs | 3 +- datafusion/sql/src/select.rs | 4 +- datafusion/sql/src/statement.rs | 13 +- datafusion/sql/src/unparser/expr.rs | 120 +++-------- datafusion/sql/src/unparser/mod.rs | 2 - datafusion/sql/src/unparser/plan.rs | 17 +- datafusion/sql/src/unparser/rewrite.rs | 37 ++-- .../substrait/src/logical_plan/consumer.rs | 12 +- .../substrait/src/logical_plan/producer.rs | 77 ++++--- .../using-the-dataframe-api.md | 4 +- 57 files changed, 704 insertions(+), 772 deletions(-) diff --git a/datafusion-examples/examples/file_stream_provider.rs b/datafusion-examples/examples/file_stream_provider.rs index 4db7e0200f53..e4fd937fd373 100644 --- a/datafusion-examples/examples/file_stream_provider.rs +++ b/datafusion-examples/examples/file_stream_provider.rs @@ -39,7 +39,7 @@ mod non_windows { use datafusion::datasource::TableProvider; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{exec_err, Result}; - use datafusion_expr::Expr; + use datafusion_expr::SortExpr; // Number of lines written to FIFO const TEST_BATCH_SIZE: usize = 5; @@ -49,7 +49,7 @@ mod non_windows { fn fifo_table( schema: SchemaRef, path: impl Into, - sort: Vec>, + sort: Vec>, ) -> Arc { let source = FileStreamProvider::new_file(schema, path.into()) .with_batch_size(TEST_BATCH_SIZE) diff --git a/datafusion/core/src/dataframe/mod.rs 
b/datafusion/core/src/dataframe/mod.rs index c516c7985d54..5dbeb535a546 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -52,7 +52,7 @@ use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions, }; -use datafusion_expr::{case, is_null, lit}; +use datafusion_expr::{case, is_null, lit, SortExpr}; use datafusion_expr::{ utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; @@ -577,7 +577,7 @@ impl DataFrame { self, on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, ) -> Result { let plan = LogicalPlanBuilder::from(self.plan) .distinct_on(on_expr, select_expr, sort_expr)? @@ -776,6 +776,15 @@ impl DataFrame { }) } + /// Apply a sort by provided expressions with default direction + pub fn sort_by(self, expr: Vec) -> Result { + self.sort( + expr.into_iter() + .map(|e| e.sort(true, false)) + .collect::>(), + ) + } + /// Sort the DataFrame by the specified sorting expressions. 
/// /// Note that any expression can be turned into @@ -797,7 +806,7 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` - pub fn sort(self, expr: Vec) -> Result { + pub fn sort(self, expr: Vec) -> Result { let plan = LogicalPlanBuilder::from(self.plan).sort(expr)?.build()?; Ok(DataFrame { session_state: self.session_state, diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 552977baba17..db90262edbf8 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -31,7 +31,6 @@ use crate::datasource::{ }; use crate::error::Result; use crate::execution::context::{SessionConfig, SessionState}; -use crate::logical_expr::Expr; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::config::TableOptions; @@ -41,6 +40,7 @@ use datafusion_common::{ }; use async_trait::async_trait; +use datafusion_expr::SortExpr; /// Options that control the reading of CSV files. /// @@ -84,7 +84,7 @@ pub struct CsvReadOptions<'a> { /// File compression type pub file_compression_type: FileCompressionType, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for CsvReadOptions<'a> { @@ -199,7 +199,7 @@ impl<'a> CsvReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -231,7 +231,7 @@ pub struct ParquetReadOptions<'a> { /// based on data in file. 
pub schema: Option<&'a Schema>, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for ParquetReadOptions<'a> { @@ -278,7 +278,7 @@ impl<'a> ParquetReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -397,7 +397,7 @@ pub struct NdJsonReadOptions<'a> { /// Flag indicating whether this file may be unbounded (as in a FIFO file). pub infinite: bool, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for NdJsonReadOptions<'a> { @@ -452,7 +452,7 @@ impl<'a> NdJsonReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index f6e938b72dab..dbeaf5dfcc36 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -102,11 +102,10 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { } // TODO other expressions are not handled yet: - // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases + // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context Expr::AggregateFunction { .. } - | Expr::Sort { .. } | Expr::WindowFunction { .. } | Expr::Wildcard { .. } | Expr::Unnest { .. 
} diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index a0345a38e40c..1f5fa738b253 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -33,8 +33,8 @@ use crate::datasource::{ use crate::execution::context::SessionState; use datafusion_catalog::TableProvider; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::TableType; use datafusion_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}; +use datafusion_expr::{SortExpr, TableType}; use datafusion_physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}; use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; @@ -222,7 +222,7 @@ pub struct ListingOptions { /// ordering (encapsulated by a `Vec`). If there aren't /// multiple equivalent orderings, the outer `Vec` will have a /// single element. - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl ListingOptions { @@ -385,7 +385,7 @@ impl ListingOptions { /// /// assert_eq!(listing_options.file_sort_order, file_sort_order); /// ``` - pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -909,8 +909,7 @@ impl TableProvider for ListingTable { keep_partition_by_columns, }; - let unsorted: Vec> = vec![]; - let order_requirements = if self.options().file_sort_order != unsorted { + let order_requirements = if !self.options().file_sort_order.is_empty() { // Multiple sort orders in outer vec are equivalent, so we pass only the first one let ordering = self .try_create_output_ordering()? 
@@ -1160,11 +1159,6 @@ mod tests { // (file_sort_order, expected_result) let cases = vec![ (vec![], Ok(vec![])), - // not a sort expr - ( - vec![vec![col("string_col")]], - Err("Expected Expr::Sort in output_ordering, but got string_col"), - ), // sort expr, but non column ( vec![vec![ diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 44e01e71648a..cef7f210e118 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -43,6 +43,7 @@ use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use datafusion_catalog::Session; +use datafusion_expr::SortExpr; use futures::StreamExt; use log::debug; use parking_lot::Mutex; @@ -64,7 +65,7 @@ pub struct MemTable { column_defaults: HashMap, /// Optional pre-known sort order(s). Must be `SortExpr`s. /// inserting data into this table removes the order - pub sort_order: Arc>>>, + pub sort_order: Arc>>>, } impl MemTable { @@ -118,7 +119,7 @@ impl MemTable { /// /// Note that multiple sort orders are supported, if some are known to be /// equivalent, - pub fn with_sort_order(self, mut sort_order: Vec>) -> Self { + pub fn with_sort_order(self, mut sort_order: Vec>) -> Self { std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order); self } diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 1c9924735735..55e88e572be1 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -50,38 +50,39 @@ pub use statistics::get_statistics_with_limit; use arrow_schema::{Schema, SortOptions}; use datafusion_common::{plan_err, Result}; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, SortExpr}; use datafusion_physical_expr::{expressions, LexOrdering, PhysicalSortExpr}; fn create_ordering( schema: &Schema, - sort_order: &[Vec], + sort_order: &[Vec], ) -> Result> { let mut all_sort_orders = vec![]; for exprs in sort_order { // 
Construct PhysicalSortExpr objects from Expr objects: let mut sort_exprs = vec![]; - for expr in exprs { - match expr { - Expr::Sort(sort) => match sort.expr.as_ref() { - Expr::Column(col) => match expressions::col(&col.name, schema) { - Ok(expr) => { - sort_exprs.push(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } - // Cannot find expression in the projected_schema, stop iterating - // since rest of the orderings are violated - Err(_) => break, + for sort in exprs { + match sort.expr.as_ref() { + Expr::Column(col) => match expressions::col(&col.name, schema) { + Ok(expr) => { + sort_exprs.push(PhysicalSortExpr { + expr, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); } - expr => return plan_err!("Expected single column references in output_ordering, got {expr}"), + // Cannot find expression in the projected_schema, stop iterating + // since rest of the orderings are violated + Err(_) => break, + }, + expr => { + return plan_err!( + "Expected single column references in output_ordering, got {expr}" + ) } - expr => return plan_err!("Expected Expr::Sort in output_ordering, but got {expr}"), } } if !sort_exprs.is_empty() { diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index bfa5488e5b5e..3ea467539adc 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -979,7 +979,7 @@ mod tests { name: &'static str, file_schema: Schema, files: Vec, - sort: Vec, + sort: Vec, expected_result: Result>, &'static str>, } diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index b53fe8663178..ef6d195cdaff 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -33,7 +33,7 @@ use 
arrow_schema::SchemaRef; use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; @@ -248,7 +248,7 @@ impl StreamProvider for FileStreamProvider { #[derive(Debug)] pub struct StreamConfig { source: Arc, - order: Vec>, + order: Vec>, constraints: Constraints, } @@ -263,7 +263,7 @@ impl StreamConfig { } /// Specify a sort order for the stream - pub fn with_order(mut self, order: Vec>) -> Self { + pub fn with_order(mut self, order: Vec>) -> Self { self.order = order; self } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index fe8d79846630..82405dd98e30 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -73,13 +73,13 @@ use datafusion_common::{ }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ - self, physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, + physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan, - WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, SortExpr, + StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal; @@ 
-1641,31 +1641,27 @@ pub fn create_aggregate_expr_and_maybe_filter( /// Create a physical sort expression from a logical expression pub fn create_physical_sort_expr( - e: &Expr, + e: &SortExpr, input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { - if let Expr::Sort(expr::Sort { + let SortExpr { expr, asc, nulls_first, - }) = e - { - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } else { - internal_err!("Expects a sort expression") - } + } = e; + Ok(PhysicalSortExpr { + expr: create_physical_expr(expr, input_dfschema, execution_props)?, + options: SortOptions { + descending: !asc, + nulls_first: *nulls_first, + }, + }) } /// Create vector of physical sort expression from a vector of logical expression pub fn create_physical_sort_exprs( - exprs: &[Expr], + exprs: &[SortExpr], input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index faa9378535fd..dd8b697666ee 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -46,7 +46,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::TableReference; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::{expressions, EquivalenceProperties, PhysicalExpr}; @@ -360,7 +360,7 @@ pub fn register_unbounded_file_with_ordering( schema: SchemaRef, file_path: &Path, table_name: &str, - file_sort_order: Vec>, + file_sort_order: Vec>, ) -> Result<()> { let source = FileStreamProvider::new_file(schema, file_path.into()); let config = 
StreamConfig::new(Arc::new(source)).with_order(file_sort_order); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 86cacbaa06d8..c5b9db7588e9 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -184,7 +184,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) - .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .order_by(vec![Sort::new(Box::new(col("a")), false, true)]) .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), @@ -352,7 +352,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { .unwrap() .select(vec![col("a")]) .unwrap() - .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))]) + .sort(vec![Sort::new(Box::new(col("b")), false, true)]) .unwrap(); let results = df.collect().await.unwrap(); @@ -396,7 +396,7 @@ async fn sort_on_distinct_columns() -> Result<()> { .unwrap() .distinct() .unwrap() - .sort(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .sort(vec![Sort::new(Box::new(col("a")), false, true)]) .unwrap(); let results = df.collect().await.unwrap(); @@ -435,7 +435,7 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> { .await? .select(vec![col("a")])? .distinct()? - .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))]) + .sort(vec![Sort::new(Box::new(col("b")), false, true)]) .unwrap_err(); assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions b must appear in select list"); Ok(()) @@ -599,8 +599,8 @@ async fn test_grouping_sets() -> Result<()> { .await? .aggregate(vec![grouping_set_expr], vec![count(col("a"))])? 
.sort(vec![ - Expr::Sort(Sort::new(Box::new(col("a")), false, true)), - Expr::Sort(Sort::new(Box::new(col("b")), false, true)), + Sort::new(Box::new(col("a")), false, true), + Sort::new(Box::new(col("b")), false, true), ])?; let results = df.collect().await?; @@ -640,8 +640,8 @@ async fn test_grouping_sets_count() -> Result<()> { .await? .aggregate(vec![grouping_set_expr], vec![count(lit(1))])? .sort(vec![ - Expr::Sort(Sort::new(Box::new(col("c1")), false, true)), - Expr::Sort(Sort::new(Box::new(col("c2")), false, true)), + Sort::new(Box::new(col("c1")), false, true), + Sort::new(Box::new(col("c2")), false, true), ])?; let results = df.collect().await?; @@ -687,8 +687,8 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { ], )? .sort(vec![ - Expr::Sort(Sort::new(Box::new(col("c1")), false, true)), - Expr::Sort(Sort::new(Box::new(col("c2")), false, true)), + Sort::new(Box::new(col("c1")), false, true), + Sort::new(Box::new(col("c2")), false, true), ])?; let results = df.collect().await?; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 051d65652633..cbd892672152 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -20,7 +20,7 @@ use arrow_array::builder::{ListBuilder, StringBuilder}; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Field}; use datafusion::prelude::*; -use datafusion_common::{assert_contains, DFSchema, ScalarValue}; +use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::ExprFunctionExt; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; @@ -167,21 +167,6 @@ fn test_list_range() { ); } -#[tokio::test] -async fn test_aggregate_error() { - let err = first_value_udaf() - .call(vec![col("props")]) - // not a sort column - .order_by(vec![col("id")]) - .build() - .unwrap_err() - .to_string(); - 
assert_contains!( - err, - "Error during planning: ORDER BY expressions must be Expr::Sort" - ); -} - #[tokio::test] async fn test_aggregate_ext_order_by() { let agg = first_value_udaf().call(vec![col("props")]); diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index 6efbb9b029de..cb587e3510c2 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -38,7 +38,7 @@ mod unix_test { }; use datafusion_common::instant::Instant; use datafusion_common::{exec_err, Result}; - use datafusion_expr::Expr; + use datafusion_expr::SortExpr; use futures::StreamExt; use nix::sys::stat; @@ -51,7 +51,7 @@ mod unix_test { fn fifo_table( schema: SchemaRef, path: impl Into, - sort: Vec>, + sort: Vec>, ) -> Arc { let source = FileStreamProvider::new_file(schema, path.into()) .with_batch_size(TEST_BATCH_SIZE) diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 9889ce2ae562..95d97709f319 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -226,7 +226,7 @@ impl SortedData { } /// Return the sort expression to use for this data, depending on the type - fn sort_expr(&self) -> Vec { + fn sort_expr(&self) -> Vec { match self { Self::I32 { .. } | Self::F64 { .. } | Self::Str { .. 
} => { vec![datafusion_expr::col("x").sort(true, true)] diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 62ba113da0d3..da27cf8869d1 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -97,7 +97,8 @@ use datafusion::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; -use datafusion_expr::Projection; +use datafusion_expr::tree_node::replace_sort_expression; +use datafusion_expr::{Projection, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; @@ -392,7 +393,7 @@ struct TopKPlanNode { input: LogicalPlan, /// The sort expression (this example only supports a single sort /// expr) - expr: Expr, + expr: SortExpr, } impl Debug for TopKPlanNode { @@ -418,7 +419,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { } fn expressions(&self) -> Vec { - vec![self.expr.clone()] + vec![self.expr.expr.as_ref().clone()] } /// For example: `TopK: k=10` @@ -436,7 +437,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { Ok(Self { k: self.k, input: inputs.swap_remove(0), - expr: exprs.swap_remove(0), + expr: replace_sort_expression(self.expr.clone(), exprs.swap_remove(0)), }) } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 85ba80396c8e..b81c02ccd0b7 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -289,10 +289,6 @@ pub enum Expr { /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. TryCast(TryCast), - /// A sort expression, that can be used to sort values. - /// - /// See [Expr::sort] for more details - Sort(Sort), /// Represents the call of a scalar function with a set of arguments. 
ScalarFunction(ScalarFunction), /// Calls an aggregate function with arguments, and optional @@ -633,6 +629,23 @@ impl Sort { } } +impl Display for Sort { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.expr)?; + if self.asc { + write!(f, " ASC")?; + } else { + write!(f, " DESC")?; + } + if self.nulls_first { + write!(f, " NULLS FIRST")?; + } else { + write!(f, " NULLS LAST")?; + } + Ok(()) + } +} + /// Aggregate function /// /// See also [`ExprFunctionExt`] to set these fields on `Expr` @@ -649,7 +662,7 @@ pub struct AggregateFunction { /// Optional filter pub filter: Option>, /// Optional ordering - pub order_by: Option>, + pub order_by: Option>, pub null_treatment: Option, } @@ -660,7 +673,7 @@ impl AggregateFunction { args: Vec, distinct: bool, filter: Option>, - order_by: Option>, + order_by: Option>, null_treatment: Option, ) -> Self { Self { @@ -785,7 +798,7 @@ pub struct WindowFunction { /// List of partition by expressions pub partition_by: Vec, /// List of order by expressions - pub order_by: Vec, + pub order_by: Vec, /// Window frame pub window_frame: window_frame::WindowFrame, /// Specifies how NULL value is treated: ignore or respect @@ -1141,7 +1154,6 @@ impl Expr { Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", Expr::ScalarVariable(..) => "ScalarVariable", - Expr::Sort { .. } => "Sort", Expr::TryCast { .. } => "TryCast", Expr::WindowFunction { .. } => "WindowFunction", Expr::Wildcard { .. } => "Wildcard", @@ -1227,14 +1239,9 @@ impl Expr { Expr::Like(Like::new(true, Box::new(self), Box::new(other), None, true)) } - /// Return the name to use for the specific Expr, recursing into - /// `Expr::Sort` as appropriate + /// Return the name to use for the specific Expr pub fn name_for_alias(&self) -> Result { - match self { - // call Expr::display_name() on a Expr::Sort will throw an error - Expr::Sort(Sort { expr, .. 
}) => expr.name_for_alias(), - expr => Ok(expr.schema_name().to_string()), - } + Ok(self.schema_name().to_string()) } /// Ensure `expr` has the name as `original_name` by adding an @@ -1250,14 +1257,7 @@ impl Expr { /// Return `self AS name` alias expression pub fn alias(self, name: impl Into) -> Expr { - match self { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => Expr::Sort(Sort::new(Box::new(expr.alias(name)), asc, nulls_first)), - _ => Expr::Alias(Alias::new(self, None::<&str>, name.into())), - } + Expr::Alias(Alias::new(self, None::<&str>, name.into())) } /// Return `self AS name` alias expression with a specific qualifier @@ -1266,18 +1266,7 @@ impl Expr { relation: Option>, name: impl Into, ) -> Expr { - match self { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => Expr::Sort(Sort::new( - Box::new(expr.alias_qualified(relation, name)), - asc, - nulls_first, - )), - _ => Expr::Alias(Alias::new(self, relation, name.into())), - } + Expr::Alias(Alias::new(self, relation, name.into())) } /// Remove an alias from an expression if one exists. @@ -1372,14 +1361,14 @@ impl Expr { Expr::IsNotNull(Box::new(self)) } - /// Create a sort expression from an existing expression. + /// Create a sort configuration from an existing expression. /// /// ``` /// # use datafusion_expr::col; /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST /// ``` - pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { - Expr::Sort(Sort::new(Box::new(self), asc, nulls_first)) + pub fn sort(self, asc: bool, nulls_first: bool) -> Sort { + Sort::new(Box::new(self), asc, nulls_first) } /// Return `IsTrue(Box(self))` @@ -1655,7 +1644,6 @@ impl Expr { | Expr::Wildcard { .. } | Expr::WindowFunction(..) | Expr::Literal(..) - | Expr::Sort(..) | Expr::Placeholder(..) 
=> false, } } @@ -1752,14 +1740,6 @@ impl Expr { }) => { data_type.hash(hasher); } - Expr::Sort(Sort { - expr: _expr, - asc, - nulls_first, - }) => { - asc.hash(hasher); - nulls_first.hash(hasher); - } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { func.hash(hasher); } @@ -1871,7 +1851,6 @@ impl<'a> Display for SchemaDisplay<'a> { Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(..) - | Expr::Sort(_) | Expr::OuterReferenceColumn(..) | Expr::Placeholder(_) | Expr::Wildcard { .. } => write!(f, "{}", self.0), @@ -1901,7 +1880,7 @@ impl<'a> Display for SchemaDisplay<'a> { }; if let Some(order_by) = order_by { - write!(f, " ORDER BY [{}]", schema_name_from_exprs(order_by)?)?; + write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; }; Ok(()) @@ -2107,7 +2086,7 @@ impl<'a> Display for SchemaDisplay<'a> { } if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", schema_name_from_exprs(order_by)?)?; + write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; }; write!(f, " {window_frame}") @@ -2144,6 +2123,24 @@ fn schema_name_from_exprs_inner(exprs: &[Expr], sep: &str) -> Result Result { + let mut s = String::new(); + for (i, e) in sorts.iter().enumerate() { + if i > 0 { + write!(&mut s, ", ")?; + } + let ordering = if e.asc { "ASC" } else { "DESC" }; + let nulls_ordering = if e.nulls_first { + "NULLS FIRST" + } else { + "NULLS LAST" + }; + write!(&mut s, "{} {} {}", e.expr, ordering, nulls_ordering)?; + } + + Ok(s) +} + /// Format expressions for display as part of a logical plan. In many cases, this will produce /// similar output to `Expr.name()` except that column names will be prefixed with '#'. 
impl fmt::Display for Expr { @@ -2203,22 +2200,6 @@ impl fmt::Display for Expr { }) => write!(f, "{expr} IN ({subquery:?})"), Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), Expr::BinaryExpr(expr) => write!(f, "{expr}"), - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - if *asc { - write!(f, "{expr} ASC")?; - } else { - write!(f, "{expr} DESC")?; - } - if *nulls_first { - write!(f, " NULLS FIRST") - } else { - write!(f, " NULLS LAST") - } - } Expr::ScalarFunction(fun) => { fmt_function(f, fun.name(), false, &fun.args, true) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1e0b601146dd..8d01712b95ad 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -26,9 +26,9 @@ use crate::function::{ StateFieldsArgs, }; use crate::{ - conditional_expressions::CaseBuilder, logical_plan::Subquery, AggregateUDF, Expr, - LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, - Volatility, + conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, + AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, + Signature, Volatility, }; use crate::{ AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, @@ -723,9 +723,7 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// ``` pub trait ExprFunctionExt { /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(self, order_by: Vec) -> ExprFuncBuilder; + fn order_by(self, order_by: Vec) -> ExprFuncBuilder; /// Add `FILTER ` fn filter(self, filter: Expr) -> ExprFuncBuilder; /// Add `DISTINCT` @@ -753,7 +751,7 @@ pub enum ExprFuncKind { #[derive(Debug, Clone)] pub struct ExprFuncBuilder { fun: Option, - order_by: Option>, + order_by: Option>, filter: Option, distinct: bool, null_treatment: Option, @@ -798,16 +796,6 @@ impl ExprFuncBuilder { ); }; - if let Some(order_by) = &order_by { - for expr in order_by.iter() { 
- if !matches!(expr, Expr::Sort(_)) { - return plan_err!( - "ORDER BY expressions must be Expr::Sort, found {expr:?}" - ); - } - } - } - let fun_expr = match fun { ExprFuncKind::Aggregate(mut udaf) => { udaf.order_by = order_by; @@ -833,9 +821,7 @@ impl ExprFuncBuilder { impl ExprFunctionExt for ExprFuncBuilder { /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { + fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { self.order_by = Some(order_by); self } @@ -873,7 +859,7 @@ impl ExprFunctionExt for ExprFuncBuilder { } impl ExprFunctionExt for Expr { - fn order_by(self, order_by: Vec) -> ExprFuncBuilder { + fn order_by(self, order_by: Vec) -> ExprFuncBuilder { let mut builder = match self { Expr::AggregateFunction(udaf) => { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 768c4aabc840..b809b015d929 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -21,7 +21,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; -use crate::expr::{Alias, Unnest}; +use crate::expr::{Alias, Sort, Unnest}; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; @@ -117,6 +117,20 @@ pub fn normalize_cols( .collect() } +pub fn normalize_sorts( + sorts: impl IntoIterator>, + plan: &LogicalPlan, +) -> Result> { + sorts + .into_iter() + .map(|e| { + let sort = e.into(); + normalize_col(*sort.expr, plan) + .map(|expr| Sort::new(Box::new(expr), sort.asc, sort.nulls_first)) + }) + .collect() +} + /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. 
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { @@ -335,7 +349,6 @@ mod test { use std::ops::Add; use super::*; - use crate::expr::Sort; use crate::{col, lit, Cast}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; @@ -496,12 +509,6 @@ mod test { // change literal type from i32 to i64 test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64))); - - // SortExpr a+1 ==> b + 2 - test_rewrite( - Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)), - Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)), - ); } /// rewrites `expr_from` to `rewrite_to` using @@ -524,15 +531,8 @@ mod test { }; let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); - let original_name = match &expr_from { - Expr::Sort(Sort { expr, .. }) => expr.schema_name().to_string(), - expr => expr.schema_name().to_string(), - }; - - let new_name = match &expr { - Expr::Sort(Sort { expr, .. }) => expr.schema_name().to_string(), - expr => expr.schema_name().to_string(), - }; + let original_name = expr_from.schema_name().to_string(); + let new_name = expr.schema_name().to_string(); assert_eq!( original_name, new_name, diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index bbb855801c3e..af5b8c4f9177 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -17,9 +17,9 @@ //! 
Rewrite for order by expressions -use crate::expr::{Alias, Sort}; +use crate::expr::Alias; use crate::expr_rewriter::normalize_col; -use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; +use crate::{expr::Sort, Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; @@ -27,28 +27,18 @@ use datafusion_common::{Column, Result}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output /// For example, `max(x)` is written to `col("max(x)")` pub fn rewrite_sort_cols_by_aggs( - exprs: impl IntoIterator>, + sorts: impl IntoIterator>, plan: &LogicalPlan, -) -> Result> { - exprs +) -> Result> { + sorts .into_iter() .map(|e| { - let expr = e.into(); - match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let sort = Expr::Sort(Sort::new( - Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), - asc, - nulls_first, - )); - Ok(sort) - } - expr => Ok(expr), - } + let sort = e.into(); + Ok(Sort::new( + Box::new(rewrite_sort_col_by_aggs(*sort.expr, plan)?), + sort.asc, + sort.nulls_first, + )) }) .collect() } @@ -289,8 +279,8 @@ mod test { struct TestCase { desc: &'static str, - input: Expr, - expected: Expr, + input: Sort, + expected: Sort, } impl TestCase { @@ -332,7 +322,7 @@ mod test { .unwrap() } - fn sort(expr: Expr) -> Expr { + fn sort(expr: Expr) -> Sort { let asc = true; let nulls_first = true; expr.sort(asc, nulls_first) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3920a1a3517c..894b7e58d954 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -18,7 +18,7 @@ use super::{Between, Expr, Like}; use crate::expr::{ AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, - ScalarFunction, Sort, TryCast, Unnest, WindowFunction, + ScalarFunction, TryCast, Unnest, WindowFunction, }; use 
crate::type_coercion::binary::get_result_type; use crate::type_coercion::functions::{ @@ -107,7 +107,7 @@ impl ExprSchemable for Expr { }, _ => expr.get_type(schema), }, - Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema), + Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), @@ -280,10 +280,9 @@ impl ExprSchemable for Expr { /// column that does not exist in the schema. fn nullable(&self, input_schema: &dyn ExprSchema) -> Result { match self { - Expr::Alias(Alias { expr, .. }) - | Expr::Not(expr) - | Expr::Negative(expr) - | Expr::Sort(Sort { expr, .. }) => expr.nullable(input_schema), + Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => { + expr.nullable(input_schema) + } Expr::InList(InList { expr, list, .. }) => { // Avoid inspecting too many expressions. @@ -422,9 +421,7 @@ impl ExprSchemable for Expr { }, _ => expr.data_type_and_nullable(schema), }, - Expr::Sort(Sort { expr, .. 
}) | Expr::Negative(expr) => { - expr.data_type_and_nullable(schema) - } + Expr::Negative(expr) => expr.data_type_and_nullable(schema), Expr::Column(c) => schema .data_type_and_nullable(c) .map(|(d, n)| (d.clone(), n)), diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2c2300b123c2..f5770167861b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -23,10 +23,10 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::dml::CopyTo; -use crate::expr::Alias; +use crate::expr::{Alias, Sort as SortExpr}; use crate::expr_rewriter::{ coerce_plan_expr_for_schema, normalize_col, - normalize_col_with_schemas_and_ambiguity_check, normalize_cols, + normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts, rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ @@ -541,19 +541,31 @@ impl LogicalPlanBuilder { plan_err!("For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list") } + /// Apply a sort by provided expressions with default direction + pub fn sort_by( + self, + expr: impl IntoIterator> + Clone, + ) -> Result { + self.sort( + expr.into_iter() + .map(|e| e.into().sort(true, false)) + .collect::>(), + ) + } + /// Apply a sort pub fn sort( self, - exprs: impl IntoIterator> + Clone, + sorts: impl IntoIterator> + Clone, ) -> Result { - let exprs = rewrite_sort_cols_by_aggs(exprs, &self.plan)?; + let sorts = rewrite_sort_cols_by_aggs(sorts, &self.plan)?; let schema = self.plan.schema(); // Collect sort columns that are missing in the input plan's schema let mut missing_cols: Vec = vec![]; - exprs.iter().try_for_each::<_, Result<()>>(|expr| { - let columns = expr.column_refs(); + sorts.iter().try_for_each::<_, Result<()>>(|sort| { + let columns = sort.expr.column_refs(); columns.into_iter().for_each(|c| { if !schema.has_column(c) { @@ -566,7 +578,7 @@ impl LogicalPlanBuilder { if 
missing_cols.is_empty() { return Ok(Self::new(LogicalPlan::Sort(Sort { - expr: normalize_cols(exprs, &self.plan)?, + expr: normalize_sorts(sorts, &self.plan)?, input: self.plan, fetch: None, }))); @@ -582,7 +594,7 @@ impl LogicalPlanBuilder { is_distinct, )?; let sort_plan = LogicalPlan::Sort(Sort { - expr: normalize_cols(exprs, &plan)?, + expr: normalize_sorts(sorts, &plan)?, input: Arc::new(plan), fetch: None, }); @@ -618,7 +630,7 @@ impl LogicalPlanBuilder { self, on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, ) -> Result { Ok(Self::new(LogicalPlan::Distinct(Distinct::On( DistinctOn::try_new(on_expr, select_expr, sort_expr, self.plan)?, @@ -1708,8 +1720,8 @@ mod tests { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? .sort(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("state")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("salary")), false, false)), + expr::Sort::new(Box::new(col("state")), true, true), + expr::Sort::new(Box::new(col("salary")), false, false), ])? .build()?; @@ -2135,8 +2147,8 @@ mod tests { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? .sort(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("state")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("salary")), false, false)), + expr::Sort::new(Box::new(col("state")), true, true), + expr::Sort::new(Box::new(col("salary")), false, false), ])? 
.build()?; diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index ad0fcd2d4771..3fc43200efe6 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -22,8 +22,9 @@ use std::{ hash::{Hash, Hasher}, }; -use crate::{Expr, LogicalPlan, Volatility}; +use crate::{Expr, LogicalPlan, SortExpr, Volatility}; +use crate::expr::Sort; use arrow::datatypes::DataType; use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference}; use sqlparser::ast::Ident; @@ -204,7 +205,7 @@ pub struct CreateExternalTable { /// SQL used to create the table, if available pub definition: Option, /// Order expressions supplied by user - pub order_exprs: Vec>, + pub order_exprs: Vec>, /// Whether the table is an infinite streams pub unbounded: bool, /// Table(provider) specific options @@ -365,7 +366,7 @@ pub struct CreateIndex { pub name: Option, pub table: TableReference, pub using: Option, - pub columns: Vec, + pub columns: Vec, pub unique: bool, pub if_not_exists: bool, pub schema: DFSchemaRef, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 359de2d30a57..8e6ec762f549 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -26,7 +26,9 @@ use super::dml::CopyTo; use super::DdlStatement; use crate::builder::{change_redundant_column, unnest_with_options}; use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction}; -use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols, NamePreserver}; +use crate::expr_rewriter::{ + create_col_from_scalar_expr, normalize_cols, normalize_sorts, NamePreserver, +}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; @@ -51,6 +53,7 @@ use datafusion_common::{ // backwards compatibility use 
crate::display::PgJsonVisitor; +use crate::tree_node::replace_sort_expressions; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -884,8 +887,12 @@ impl LogicalPlan { Aggregate::try_new(Arc::new(inputs.swap_remove(0)), expr, agg_expr) .map(LogicalPlan::Aggregate) } - LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { - expr, + LogicalPlan::Sort(Sort { + expr: sort_expr, + fetch, + .. + }) => Ok(LogicalPlan::Sort(Sort { + expr: replace_sort_expressions(sort_expr.clone(), expr), input: Arc::new(inputs.swap_remove(0)), fetch: *fetch, })), @@ -1014,14 +1021,11 @@ impl LogicalPlan { }) => { let sort_expr = expr.split_off(on_expr.len() + select_expr.len()); let select_expr = expr.split_off(on_expr.len()); + assert!(sort_expr.is_empty(), "with_new_exprs for Distinct does not support sort expressions"); Distinct::On(DistinctOn::try_new( expr, select_expr, - if !sort_expr.is_empty() { - Some(sort_expr) - } else { - None - }, + None, // no sort expressions accepted Arc::new(inputs.swap_remove(0)), )?) } @@ -2559,7 +2563,7 @@ pub struct DistinctOn { /// The `ORDER BY` clause, whose initial expressions must match those of the `ON` clause when /// present. Note that those matching expressions actually wrap the `ON` expressions with /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). - pub sort_expr: Option>, + pub sort_expr: Option>, /// The logical plan that is being DISTINCT'd pub input: Arc, /// The schema description of the DISTINCT ON output @@ -2571,7 +2575,7 @@ impl DistinctOn { pub fn try_new( on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, input: Arc, ) -> Result { if on_expr.is_empty() { @@ -2606,20 +2610,15 @@ impl DistinctOn { /// Try to update `self` with a new sort expressions. /// /// Validates that the sort expressions are a super-set of the `ON` expressions. 
- pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { - let sort_expr = normalize_cols(sort_expr, self.input.as_ref())?; + pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { + let sort_expr = normalize_sorts(sort_expr, self.input.as_ref())?; // Check that the left-most sort expressions are the same as the `ON` expressions. let mut matched = true; for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) { - match sort { - Expr::Sort(SortExpr { expr, .. }) => { - if on != &**expr { - matched = false; - break; - } - } - _ => return plan_err!("Not a sort expression: {sort}"), + if on != &*sort.expr { + matched = false; + break; } } @@ -2833,7 +2832,7 @@ fn calc_func_dependencies_for_project( #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Sort { /// The sort expressions - pub expr: Vec, + pub expr: Vec, /// The incoming logical plan pub input: Arc, /// Optional fetch limit diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 273404c8df31..29a99a8e8886 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -46,7 +46,7 @@ use crate::{ use std::sync::Arc; use crate::expr::{Exists, InSubquery}; -use crate::tree_node::transform_option_vec; +use crate::tree_node::{transform_sort_option_vec, transform_sort_vec}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, @@ -481,7 +481,9 @@ impl LogicalPlan { .apply_until_stop(|e| f(&e))? .visit_sibling(|| filter.iter().apply_until_stop(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().apply_until_stop(f), + LogicalPlan::Sort(Sort { expr, .. 
}) => { + expr.iter().apply_until_stop(|sort| f(&sort.expr)) + } LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs @@ -507,7 +509,7 @@ impl LogicalPlan { })) => on_expr .iter() .chain(select_expr.iter()) - .chain(sort_expr.iter().flatten()) + .chain(sort_expr.iter().flatten().map(|sort| &*sort.expr)) .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) @@ -658,10 +660,10 @@ impl LogicalPlan { null_equals_null, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), + LogicalPlan::Sort(Sort { expr, input, fetch }) => { + transform_sort_vec(expr, &mut f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })) + } LogicalPlan::Extension(Extension { node }) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs @@ -709,7 +711,7 @@ impl LogicalPlan { select_expr, select_expr.into_iter().map_until_stop_and_collect(&mut f), sort_expr, - transform_option_vec(sort_expr, &mut f) + transform_sort_option_vec(sort_expr, &mut f) )? .update_data(|(on_expr, select_expr, sort_expr)| { LogicalPlan::Distinct(Distinct::On(DistinctOn { diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 450ebb6c2275..90d61bf63763 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -48,7 +48,6 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery(InSubquery{ expr, .. 
}) => vec![expr.as_ref()], Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(), @@ -98,7 +97,7 @@ impl TreeNode for Expr { expr_vec.push(f.as_ref()); } if let Some(order_by) = order_by { - expr_vec.extend(order_by); + expr_vec.extend(order_by.iter().map(|sort| sort.expr.as_ref())); } expr_vec } @@ -110,7 +109,7 @@ impl TreeNode for Expr { }) => { let mut expr_vec = args.iter().collect::>(); expr_vec.extend(partition_by); - expr_vec.extend(order_by); + expr_vec.extend(order_by.iter().map(|sort| sort.expr.as_ref())); expr_vec } Expr::InList(InList { expr, list, .. }) => { @@ -265,12 +264,6 @@ impl TreeNode for Expr { .update_data(|be| Expr::Cast(Cast::new(be, data_type))), Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => transform_box(expr, &mut f)? - .update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))), Expr::ScalarFunction(ScalarFunction { func, args }) => { transform_vec(args, &mut f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( @@ -290,7 +283,7 @@ impl TreeNode for Expr { partition_by, transform_vec(partition_by, &mut f), order_by, - transform_vec(order_by, &mut f) + transform_sort_vec(order_by, &mut f) )? .update_data(|(new_args, new_partition_by, new_order_by)| { Expr::WindowFunction(WindowFunction::new(fun, new_args)) @@ -313,7 +306,7 @@ impl TreeNode for Expr { filter, transform_option_box(filter, &mut f), order_by, - transform_option_vec(order_by, &mut f) + transform_sort_option_vec(order_by, &mut f) )? 
.map_data(|(new_args, new_filter, new_order_by)| { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( @@ -386,3 +379,41 @@ fn transform_vec Result>>( ) -> Result>> { ve.into_iter().map_until_stop_and_collect(f) } + +pub fn transform_sort_option_vec Result>>( + sorts_option: Option>, + f: &mut F, +) -> Result>>> { + sorts_option.map_or(Ok(Transformed::no(None)), |sorts| { + Ok(transform_sort_vec(sorts, f)?.update_data(Some)) + }) +} + +pub fn transform_sort_vec Result>>( + sorts: Vec, + mut f: &mut F, +) -> Result>> { + Ok(sorts + .iter() + .map(|sort| (*sort.expr).clone()) + .map_until_stop_and_collect(&mut f)? + .update_data(|transformed_exprs| { + replace_sort_expressions(sorts, transformed_exprs) + })) +} + +pub fn replace_sort_expressions(sorts: Vec, new_expr: Vec) -> Vec { + assert_eq!(sorts.len(), new_expr.len()); + sorts + .into_iter() + .zip(new_expr) + .map(|(sort, expr)| replace_sort_expression(sort, expr)) + .collect() +} + +pub fn replace_sort_expression(sort: Sort, new_expr: Expr) -> Sort { + Sort { + expr: Box::new(new_expr), + ..sort + } +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index a01d5ef8973a..b6b1b5660a81 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -296,7 +296,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Case { .. } | Expr::Cast { .. } | Expr::TryCast { .. } - | Expr::Sort { .. } | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. 
} @@ -461,22 +460,20 @@ pub fn expand_qualified_wildcard( /// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)") /// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column -type WindowSortKey = Vec<(Expr, bool)>; +type WindowSortKey = Vec<(Sort, bool)>; /// Generate a sort key for a given window expr's partition_by and order_by expr pub fn generate_sort_key( partition_by: &[Expr], - order_by: &[Expr], + order_by: &[Sort], ) -> Result { let normalized_order_by_keys = order_by .iter() - .map(|e| match e { - Expr::Sort(Sort { expr, .. }) => { - Ok(Expr::Sort(Sort::new(expr.clone(), true, false))) - } - _ => plan_err!("Order by only accepts sort expressions"), + .map(|e| { + let Sort { expr, .. } = e; + Sort::new(expr.clone(), true, false) }) - .collect::>>()?; + .collect::>(); let mut final_sort_keys = vec![]; let mut is_partition_flag = vec![]; @@ -512,65 +509,61 @@ pub fn generate_sort_key( /// Compare the sort expr as PostgreSQL's common_prefix_cmp(): /// pub fn compare_sort_expr( - sort_expr_a: &Expr, - sort_expr_b: &Expr, + sort_expr_a: &Sort, + sort_expr_b: &Sort, schema: &DFSchemaRef, ) -> Ordering { - match (sort_expr_a, sort_expr_b) { - ( - Expr::Sort(Sort { - expr: expr_a, - asc: asc_a, - nulls_first: nulls_first_a, - }), - Expr::Sort(Sort { - expr: expr_b, - asc: asc_b, - nulls_first: nulls_first_b, - }), - ) => { - let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); - let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); - for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) { - match idx_a.cmp(idx_b) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - } + let Sort { + expr: expr_a, + asc: asc_a, + nulls_first: nulls_first_a, + } = sort_expr_a; + + let Sort { + expr: expr_b, + asc: asc_b, + nulls_first: nulls_first_b, + } = 
sort_expr_b; + + let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); + let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); + for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) { + match idx_a.cmp(idx_b) { + Ordering::Less => { + return Ordering::Less; } - match ref_indexes_a.len().cmp(&ref_indexes_b.len()) { - Ordering::Less => return Ordering::Greater, - Ordering::Greater => { - return Ordering::Less; - } - Ordering::Equal => {} + Ordering::Greater => { + return Ordering::Greater; } - match (asc_a, asc_b) { - (true, false) => { - return Ordering::Greater; - } - (false, true) => { - return Ordering::Less; - } - _ => {} - } - match (nulls_first_a, nulls_first_b) { - (true, false) => { - return Ordering::Less; - } - (false, true) => { - return Ordering::Greater; - } - _ => {} - } - Ordering::Equal + Ordering::Equal => {} } - _ => panic!("Sort expressions must be of type Sort"), } + match ref_indexes_a.len().cmp(&ref_indexes_b.len()) { + Ordering::Less => return Ordering::Greater, + Ordering::Greater => { + return Ordering::Less; + } + Ordering::Equal => {} + } + match (asc_a, asc_b) { + (true, false) => { + return Ordering::Greater; + } + (false, true) => { + return Ordering::Less; + } + _ => {} + } + match (nulls_first_a, nulls_first_b) { + (true, false) => { + return Ordering::Less; + } + (false, true) => { + return Ordering::Greater; + } + _ => {} + } + Ordering::Equal } /// group a slice of window expression expr by their order by expressions @@ -606,14 +599,6 @@ pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { }) } -/// Collect all deeply nested `Expr::Sort`. They are returned in order of occurrence -/// (depth first), with duplicates omitted. -pub fn find_sort_exprs(exprs: &[Expr]) -> Vec { - find_exprs_in_exprs(exprs, &|nested_expr| { - matches!(nested_expr, Expr::Sort { .. }) - }) -} - /// Collect all deeply nested `Expr::WindowFunction`. 
They are returned in order of occurrence /// (depth first), with duplicates omitted. pub fn find_window_exprs(exprs: &[Expr]) -> Vec { @@ -1376,8 +1361,7 @@ mod tests { use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::max_udaf, test::function_stub::min_udaf, - test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFrame, - WindowFunctionDefinition, + test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, }; #[test] @@ -1417,10 +1401,9 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys() -> Result<()> { - let age_asc = Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)); - let name_desc = Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)); - let created_at_desc = - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); + let age_asc = expr::Sort::new(Box::new(col("age")), true, true); + let name_desc = expr::Sort::new(Box::new(col("name")), false, true); + let created_at_desc = expr::Sort::new(Box::new(col("created_at")), false, true); let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], @@ -1471,43 +1454,6 @@ mod tests { Ok(()) } - #[test] - fn test_find_sort_exprs() -> Result<()> { - let exprs = &[ - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateUDF(max_udaf()), - vec![col("name")], - )) - .order_by(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ]) - .window_frame(WindowFrame::new(Some(false))) - .build() - .unwrap(), - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")], - )) - .order_by(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - 
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ]) - .window_frame(WindowFrame::new(Some(false))) - .build() - .unwrap(), - ]; - let expected = vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ]; - let result = find_sort_exprs(exprs); - assert_eq!(expected, result); - Ok(()) - } - #[test] fn avoid_generate_duplicate_sort_keys() -> Result<()> { let asc_or_desc = [true, false]; @@ -1516,41 +1462,41 @@ mod tests { for asc_ in asc_or_desc { for nulls_first_ in nulls_first_or_last { let order_by = &[ - Expr::Sort(Sort { + Sort { expr: Box::new(col("age")), asc: asc_, nulls_first: nulls_first_, - }), - Expr::Sort(Sort { + }, + Sort { expr: Box::new(col("name")), asc: asc_, nulls_first: nulls_first_, - }), + }, ]; let expected = vec![ ( - Expr::Sort(Sort { + Sort { expr: Box::new(col("age")), asc: asc_, nulls_first: nulls_first_, - }), + }, true, ), ( - Expr::Sort(Sort { + Sort { expr: Box::new(col("name")), asc: asc_, nulls_first: nulls_first_, - }), + }, true, ), ( - Expr::Sort(Sort { + Sort { expr: Box::new(col("created_at")), asc: true, nulls_first: false, - }), + }, true, ), ]; diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 0e1d917419f8..6c935cdcd121 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -26,7 +26,7 @@ use std::fmt::{self, Formatter}; use std::hash::Hash; -use crate::{lit, Expr}; +use crate::{expr::Sort, lit}; use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; @@ -247,7 +247,7 @@ impl WindowFrame { } /// Regularizes the ORDER BY clause of the window frame. 
- pub fn regularize_order_bys(&self, order_by: &mut Vec) -> Result<()> { + pub fn regularize_order_bys(&self, order_by: &mut Vec) -> Result<()> { match self.units { // Normally, RANGE frames require an ORDER BY clause with exactly // one column. However, an ORDER BY clause may be absent or have diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 2162442f054e..30f5d5b07561 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -32,7 +32,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, - Signature, TypeSignature, Volatility, + Signature, SortExpr, TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -40,7 +40,7 @@ use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; create_func!(FirstValue, first_value_udaf); /// Returns the first value in a group of values. 
-pub fn first_value(expression: Expr, order_by: Option>) -> Expr { +pub fn first_value(expression: Expr, order_by: Option>) -> Expr { if let Some(order_by) = order_by { first_value_udaf() .call(vec![expression]) diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index e114efb99960..35d4f91e3b6f 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -229,7 +229,7 @@ mod tests { WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) - .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .order_by(vec![Sort::new(Box::new(col("a")), false, true)]) .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index a6b9bad6c5d9..61ff4b4fd5a8 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -33,7 +33,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, - ScalarFunction, WindowFunction, + ScalarFunction, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -506,7 +506,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { | Expr::Negative(_) | Expr::Cast(_) | Expr::TryCast(_) - | Expr::Sort(_) | Expr::Wildcard { .. 
} | Expr::GroupingSet(_) | Expr::Placeholder(_) @@ -593,12 +592,12 @@ fn coerce_frame_bound( fn coerce_window_frame( window_frame: WindowFrame, schema: &DFSchema, - expressions: &[Expr], + expressions: &[Sort], ) -> Result { let mut window_frame = window_frame; let current_types = expressions .iter() - .map(|e| e.get_type(schema)) + .map(|s| s.expr.get_type(schema)) .collect::>>()?; let target_type = match window_frame.units { WindowFrameUnits::Range => { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 3a2b190359d4..25bef7e2d0e4 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -36,6 +36,7 @@ use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; +use datafusion_expr::tree_node::replace_sort_expressions; use datafusion_expr::{col, BinaryExpr, Case, Expr, ExprSchemable, Operator}; use indexmap::IndexMap; @@ -327,15 +328,17 @@ impl CommonSubexprEliminate { ) -> Result> { let Sort { expr, input, fetch } = sort; let input = Arc::unwrap_or_clone(input); - let new_sort = self.try_unary_plan(expr, input, config)?.update_data( - |(new_expr, new_input)| { + let sort_expressions = + expr.iter().map(|sort| sort.expr.as_ref().clone()).collect(); + let new_sort = self + .try_unary_plan(sort_expressions, input, config)? 
+ .update_data(|(new_expr, new_input)| { LogicalPlan::Sort(Sort { - expr: new_expr, + expr: replace_sort_expressions(expr, new_expr), input: Arc::new(new_input), fetch, }) - }, - ); + }); Ok(new_sort) } @@ -882,7 +885,6 @@ enum ExprMask { /// - [`Columns`](Expr::Column) /// - [`ScalarVariable`](Expr::ScalarVariable) /// - [`Alias`](Expr::Alias) - /// - [`Sort`](Expr::Sort) /// - [`Wildcard`](Expr::Wildcard) /// - [`AggregateFunction`](Expr::AggregateFunction) Normal, @@ -899,7 +901,6 @@ impl ExprMask { | Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Alias(..) - | Expr::Sort { .. } | Expr::Wildcard { .. } ); diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index e9d091d52b00..c460d7a93d26 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -22,7 +22,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{Aggregate, Expr, Sort}; +use datafusion_expr::{Aggregate, Expr, Sort, SortExpr}; use indexmap::IndexSet; use std::hash::{Hash, Hasher}; /// Optimization rule that eliminate duplicated expr. 
@@ -37,29 +37,15 @@ impl EliminateDuplicatedExpr { } // use this structure to avoid initial clone #[derive(Eq, Clone, Debug)] -struct SortExprWrapper { - expr: Expr, -} +struct SortExprWrapper(SortExpr); impl PartialEq for SortExprWrapper { fn eq(&self, other: &Self) -> bool { - match (&self.expr, &other.expr) { - (Expr::Sort(own_sort), Expr::Sort(other_sort)) => { - own_sort.expr == other_sort.expr - } - _ => self.expr == other.expr, - } + self.0.expr == other.0.expr } } impl Hash for SortExprWrapper { fn hash(&self, state: &mut H) { - match &self.expr { - Expr::Sort(sort) => { - sort.expr.hash(state); - } - _ => { - self.expr.hash(state); - } - } + self.0.expr.hash(state); } } impl OptimizerRule for EliminateDuplicatedExpr { @@ -82,10 +68,10 @@ impl OptimizerRule for EliminateDuplicatedExpr { let unique_exprs: Vec<_> = sort .expr .into_iter() - .map(|e| SortExprWrapper { expr: e }) + .map(SortExprWrapper) .collect::>() .into_iter() - .map(|wrapper| wrapper.expr) + .map(|wrapper| wrapper.0) .collect(); let transformed = if len != unique_exprs.len() { @@ -146,11 +132,11 @@ mod tests { fn eliminate_sort_expr() -> Result<()> { let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a"), col("a"), col("b"), col("c")])? + .sort_by(vec![col("a"), col("a"), col("b"), col("c")])? .limit(5, Some(10))? .build()?; let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a, test.b, test.c\ + \n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\ \n TableScan: test"; assert_optimized_plan_eq(plan, expected) } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index e48f37a77cd3..2503475bd8df 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -182,14 +182,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(0, Some(2))? 
- .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(2, Some(1))? .build()?; // After remove global-state, we don't record the parent // So, bottom don't know parent info, so can't eliminate. let expected = "Limit: skip=2, fetch=1\ - \n Sort: test.a, fetch=3\ + \n Sort: test.a ASC NULLS LAST, fetch=3\ \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; @@ -202,12 +202,12 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(0, Some(2))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(0, Some(1))? .build()?; let expected = "Limit: skip=0, fetch=1\ - \n Sort: test.a\ + \n Sort: test.a ASC NULLS LAST\ \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; @@ -220,12 +220,12 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(2, Some(1))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(3, Some(1))? .build()?; let expected = "Limit: skip=3, fetch=1\ - \n Sort: test.a\ + \n Sort: test.a ASC NULLS LAST\ \n Limit: skip=2, fetch=1\ \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 82149a087e63..33a58a810b08 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -284,8 +284,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::TryCast(_) | Expr::InList { .. } | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), - Expr::Sort(_) - | Expr::AggregateFunction(_) + Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. 
} | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index dff0b61c6b22..ab7880213692 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -347,13 +347,13 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(0, Some(10))? .build()?; // Should push down limit to sort let expected = "Limit: skip=0, fetch=10\ - \n Sort: test.a, fetch=10\ + \n Sort: test.a ASC NULLS LAST, fetch=10\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) @@ -364,13 +364,13 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(5, Some(10))? .build()?; // Should push down limit to sort let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a, fetch=15\ + \n Sort: test.a ASC NULLS LAST, fetch=15\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 7129ceb0fea1..f299d4542c36 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -591,7 +591,6 @@ impl<'a> ConstEvaluator<'a> { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::WindowFunction { .. } - | Expr::Sort { .. } | Expr::GroupingSet(_) | Expr::Wildcard { .. 
} | Expr::Placeholder(_) => false, diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 30cae17eaf9f..dd82b056d0a6 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -624,14 +624,14 @@ mod tests { vec![col("a")], false, None, - Some(vec![col("a")]), + Some(vec![col("a").sort(true, false)]), None, )); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 826992e132ba..19759a897068 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -75,6 +75,10 @@ message LogicalExprNodeCollection { repeated LogicalExprNode logical_expr_nodes = 1; } +message SortExprNodeCollection { + repeated SortExprNode sort_expr_nodes = 1; +} + message ListingTableScanNode { reserved 1; // was string table_name TableReference table_name = 14; @@ -92,7 +96,7 @@ message ListingTableScanNode { datafusion_common.AvroFormat avro = 12; datafusion_common.NdJsonFormat json = 15; } - repeated LogicalExprNodeCollection file_sort_order = 13; + repeated SortExprNodeCollection file_sort_order = 13; } message ViewTableScanNode { @@ -129,7 +133,7 @@ message SelectionNode { message SortNode { LogicalPlanNode input = 1; 
- repeated LogicalExprNode expr = 2; + repeated SortExprNode expr = 2; // Maximum number of highest/lowest rows to fetch; negative means no limit int64 fetch = 3; } @@ -160,7 +164,7 @@ message CreateExternalTableNode { repeated string table_partition_cols = 5; bool if_not_exists = 6; string definition = 7; - repeated LogicalExprNodeCollection order_exprs = 10; + repeated SortExprNodeCollection order_exprs = 10; bool unbounded = 11; map options = 8; datafusion_common.Constraints constraints = 12; @@ -245,7 +249,7 @@ message DistinctNode { message DistinctOnNode { repeated LogicalExprNode on_expr = 1; repeated LogicalExprNode select_expr = 2; - repeated LogicalExprNode sort_expr = 3; + repeated SortExprNode sort_expr = 3; LogicalPlanNode input = 4; } @@ -320,7 +324,6 @@ message LogicalExprNode { BetweenNode between = 9; CaseNode case_ = 10; CastNode cast = 11; - SortExprNode sort = 12; NegativeNode negative = 13; InListNode in_list = 14; Wildcard wildcard = 15; @@ -470,7 +473,7 @@ message AggregateUDFExprNode { repeated LogicalExprNode args = 2; bool distinct = 5; LogicalExprNode filter = 3; - repeated LogicalExprNode order_by = 4; + repeated SortExprNode order_by = 4; optional bytes fun_definition = 6; } @@ -503,7 +506,7 @@ message WindowExprNode { } LogicalExprNode expr = 4; repeated LogicalExprNode partition_by = 5; - repeated LogicalExprNode order_by = 6; + repeated SortExprNode order_by = 6; // repeated LogicalExprNode filter = 7; WindowFrame window_frame = 8; optional bytes fun_definition = 10; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b4d63798f080..cff58d3ddc4a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9291,9 +9291,6 @@ impl serde::Serialize for LogicalExprNode { logical_expr_node::ExprType::Cast(v) => { struct_ser.serialize_field("cast", v)?; } - logical_expr_node::ExprType::Sort(v) => { - struct_ser.serialize_field("sort", v)?; - } 
logical_expr_node::ExprType::Negative(v) => { struct_ser.serialize_field("negative", v)?; } @@ -9384,7 +9381,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "case_", "case", "cast", - "sort", "negative", "in_list", "inList", @@ -9433,7 +9429,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { Between, Case, Cast, - Sort, Negative, InList, Wildcard, @@ -9486,7 +9481,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "between" => Ok(GeneratedField::Between), "case" | "case_" => Ok(GeneratedField::Case), "cast" => Ok(GeneratedField::Cast), - "sort" => Ok(GeneratedField::Sort), "negative" => Ok(GeneratedField::Negative), "inList" | "in_list" => Ok(GeneratedField::InList), "wildcard" => Ok(GeneratedField::Wildcard), @@ -9598,13 +9592,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { return Err(serde::de::Error::duplicate_field("cast")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cast) -; - } - GeneratedField::Sort => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sort")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Sort) ; } GeneratedField::Negative => { @@ -17947,6 +17934,98 @@ impl<'de> serde::Deserialize<'de> for SortExprNode { deserializer.deserialize_struct("datafusion.SortExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SortExprNodeCollection { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.sort_expr_nodes.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SortExprNodeCollection", len)?; + if !self.sort_expr_nodes.is_empty() { + struct_ser.serialize_field("sortExprNodes", &self.sort_expr_nodes)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SortExprNodeCollection { + 
#[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "sort_expr_nodes", + "sortExprNodes", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + SortExprNodes, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "sortExprNodes" | "sort_expr_nodes" => Ok(GeneratedField::SortExprNodes), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SortExprNodeCollection; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SortExprNodeCollection") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut sort_expr_nodes__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::SortExprNodes => { + if sort_expr_nodes__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExprNodes")); + } + sort_expr_nodes__ = Some(map_.next_value()?); + } + } + } + Ok(SortExprNodeCollection { + sort_expr_nodes: sort_expr_nodes__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SortExprNodeCollection", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for SortNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 875d2af75dd7..2ce8004e3248 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -97,6 +97,12 @@ pub struct LogicalExprNodeCollection { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct SortExprNodeCollection { + #[prost(message, repeated, tag = "1")] + pub sort_expr_nodes: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct ListingTableScanNode { #[prost(message, optional, tag = "14")] pub table_name: ::core::option::Option, @@ -117,7 +123,7 @@ pub struct ListingTableScanNode { #[prost(uint32, tag = "9")] pub target_partitions: u32, #[prost(message, repeated, tag = "13")] - pub file_sort_order: ::prost::alloc::vec::Vec, + pub file_sort_order: ::prost::alloc::vec::Vec, #[prost(oneof = "listing_table_scan_node::FileFormatType", tags = "10, 11, 12, 15")] pub file_format_type: ::core::option::Option< listing_table_scan_node::FileFormatType, @@ -200,7 +206,7 @@ pub struct SortNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "2")] - pub expr: ::prost::alloc::vec::Vec, + pub expr: ::prost::alloc::vec::Vec, /// Maximum number of highest/lowest rows to fetch; 
negative means no limit #[prost(int64, tag = "3")] pub fetch: i64, @@ -256,7 +262,7 @@ pub struct CreateExternalTableNode { #[prost(string, tag = "7")] pub definition: ::prost::alloc::string::String, #[prost(message, repeated, tag = "10")] - pub order_exprs: ::prost::alloc::vec::Vec, + pub order_exprs: ::prost::alloc::vec::Vec, #[prost(bool, tag = "11")] pub unbounded: bool, #[prost(map = "string, string", tag = "8")] @@ -402,7 +408,7 @@ pub struct DistinctOnNode { #[prost(message, repeated, tag = "2")] pub select_expr: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "3")] - pub sort_expr: ::prost::alloc::vec::Vec, + pub sort_expr: ::prost::alloc::vec::Vec, #[prost(message, optional, boxed, tag = "4")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, } @@ -488,7 +494,7 @@ pub struct SubqueryAliasNode { pub struct LogicalExprNode { #[prost( oneof = "logical_expr_node::ExprType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" )] pub expr_type: ::core::option::Option, } @@ -521,8 +527,6 @@ pub mod logical_expr_node { Case(::prost::alloc::boxed::Box), #[prost(message, tag = "11")] Cast(::prost::alloc::boxed::Box), - #[prost(message, tag = "12")] - Sort(::prost::alloc::boxed::Box), #[prost(message, tag = "13")] Negative(::prost::alloc::boxed::Box), #[prost(message, tag = "14")] @@ -740,7 +744,7 @@ pub struct AggregateUdfExprNode { #[prost(message, optional, boxed, tag = "3")] pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "4")] - pub order_by: ::prost::alloc::vec::Vec, + pub order_by: ::prost::alloc::vec::Vec, #[prost(bytes = "vec", optional, tag = "6")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, } @@ -762,7 +766,7 @@ pub struct WindowExprNode { #[prost(message, 
repeated, tag = "5")] pub partition_by: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "6")] - pub order_by: ::prost::alloc::vec::Vec, + pub order_by: ::prost::alloc::vec::Vec, /// repeated LogicalExprNode filter = 7; #[prost(message, optional, tag = "8")] pub window_frame: ::core::option::Option, @@ -869,8 +873,8 @@ pub struct TryCastNode { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortExprNode { - #[prost(message, optional, boxed, tag = "1")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "1")] + pub expr: ::core::option::Option, #[prost(bool, tag = "2")] pub asc: bool, #[prost(bool, tag = "3")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index acda1298dd80..3ba1cb945e9c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -22,11 +22,11 @@ use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue, TableReference, UnnestOptions, }; -use datafusion_expr::expr::{Alias, Placeholder}; +use datafusion_expr::expr::{Alias, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - expr::{self, InList, Sort, WindowFunction}, + expr::{self, InList, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr, GroupingSet, GroupingSet::GroupingSets, @@ -267,7 +267,7 @@ pub fn parse_expr( .as_ref() .ok_or_else(|| Error::required("window_function"))?; let partition_by = parse_exprs(&expr.partition_by, registry, codec)?; - let mut order_by = parse_exprs(&expr.order_by, registry, codec)?; + let mut order_by = parse_sorts(&expr.order_by, registry, codec)?; let window_frame = expr .window_frame .as_ref() @@ -524,16 +524,6 @@ pub fn parse_expr( let 
data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::TryCast(TryCast::new(expr, data_type))) } - ExprType::Sort(sort) => Ok(Expr::Sort(Sort::new( - Box::new(parse_required_expr( - sort.expr.as_deref(), - registry, - "expr", - codec, - )?), - sort.asc, - sort.nulls_first, - ))), ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), @@ -588,7 +578,7 @@ pub fn parse_expr( parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), match pb.order_by.len() { 0 => None, - _ => Some(parse_exprs(&pb.order_by, registry, codec)?), + _ => Some(parse_sorts(&pb.order_by, registry, codec)?), }, None, ))) @@ -635,6 +625,37 @@ where Ok(res) } +pub fn parse_sorts<'a, I>( + protos: I, + registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, +) -> Result, Error> +where + I: IntoIterator, +{ + protos + .into_iter() + .map(|sort| parse_sort(sort, registry, codec)) + .collect::, Error>>() +} + +pub fn parse_sort( + sort: &protobuf::SortExprNode, + registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, +) -> Result { + Ok(Sort::new( + Box::new(parse_required_expr( + sort.expr.as_ref(), + registry, + "expr", + codec, + )?), + sort.asc, + sort.nulls_first, + )) +} + /// Parse an optional escape_char for Like, ILike, SimilarTo fn parse_escape_char(s: &str) -> Result> { match s.len() { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 67977b1795a6..bf5394ec01de 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -20,7 +20,7 @@ use std::fmt::Debug; use std::sync::Arc; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; -use crate::protobuf::{CustomTableScanNode, LogicalExprNodeCollection}; +use crate::protobuf::{CustomTableScanNode, SortExprNodeCollection}; use crate::{ convert_required, into_required, protobuf::{ @@ -62,11 
+62,13 @@ use datafusion_expr::{ EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, - DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, WindowUDF, + DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, + WindowUDF, }; use datafusion_expr::{AggregateUDF, Unnest}; use self::to_proto::{serialize_expr, serialize_exprs}; +use crate::logical_plan::to_proto::serialize_sorts; use prost::bytes::BufMut; use prost::Message; @@ -347,8 +349,8 @@ impl AsLogicalPlan for LogicalPlanNode { let mut all_sort_orders = vec![]; for order in &scan.file_sort_order { - all_sort_orders.push(from_proto::parse_exprs( - &order.logical_expr_nodes, + all_sort_orders.push(from_proto::parse_sorts( + &order.sort_expr_nodes, ctx, extension_codec, )?) @@ -476,8 +478,8 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Sort(sort) => { let input: LogicalPlan = into_logical_plan!(sort.input, ctx, extension_codec)?; - let sort_expr: Vec = - from_proto::parse_exprs(&sort.expr, ctx, extension_codec)?; + let sort_expr: Vec = + from_proto::parse_sorts(&sort.expr, ctx, extension_codec)?; LogicalPlanBuilder::from(input).sort(sort_expr)?.build() } LogicalPlanType::Repartition(repartition) => { @@ -536,8 +538,8 @@ impl AsLogicalPlan for LogicalPlanNode { let mut order_exprs = vec![]; for expr in &create_extern_table.order_exprs { - order_exprs.push(from_proto::parse_exprs( - &expr.logical_expr_nodes, + order_exprs.push(from_proto::parse_sorts( + &expr.sort_expr_nodes, ctx, extension_codec, )?); @@ -772,7 +774,7 @@ impl AsLogicalPlan for LogicalPlanNode { )?; let sort_expr = match distinct_on.sort_expr.len() { 0 => None, - _ => Some(from_proto::parse_exprs( + _ => Some(from_proto::parse_sorts( &distinct_on.sort_expr, ctx, extension_codec, @@ -981,10 +983,10 @@ impl AsLogicalPlan for LogicalPlanNode { let options = listing_table.options(); - let mut exprs_vec: Vec = 
vec![]; + let mut exprs_vec: Vec = vec![]; for order in &options.file_sort_order { - let expr_vec = LogicalExprNodeCollection { - logical_expr_nodes: serialize_exprs(order, extension_codec)?, + let expr_vec = SortExprNodeCollection { + sort_expr_nodes: serialize_sorts(order, extension_codec)?, }; exprs_vec.push(expr_vec); } @@ -1114,7 +1116,7 @@ impl AsLogicalPlan for LogicalPlanNode { )?; let sort_expr = match sort_expr { None => vec![], - Some(sort_expr) => serialize_exprs(sort_expr, extension_codec)?, + Some(sort_expr) => serialize_sorts(sort_expr, extension_codec)?, }; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DistinctOn(Box::new( @@ -1258,13 +1260,13 @@ impl AsLogicalPlan for LogicalPlanNode { input.as_ref(), extension_codec, )?; - let selection_expr: Vec = - serialize_exprs(expr, extension_codec)?; + let sort_expr: Vec = + serialize_sorts(expr, extension_codec)?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Sort(Box::new( protobuf::SortNode { input: Some(Box::new(input)), - expr: selection_expr, + expr: sort_expr, fetch: fetch.map(|f| f as i64).unwrap_or(-1i64), }, ))), @@ -1334,10 +1336,10 @@ impl AsLogicalPlan for LogicalPlanNode { column_defaults, }, )) => { - let mut converted_order_exprs: Vec = vec![]; + let mut converted_order_exprs: Vec = vec![]; for order in order_exprs { - let temp = LogicalExprNodeCollection { - logical_expr_nodes: serialize_exprs(order, extension_codec)?, + let temp = SortExprNodeCollection { + sort_expr_nodes: serialize_sorts(order, extension_codec)?, }; converted_order_exprs.push(temp); } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index bb7bf84a3387..b937c03f79d9 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -22,12 +22,12 @@ use datafusion_common::{TableReference, UnnestOptions}; use datafusion_expr::expr::{ self, Alias, Between, BinaryExpr, 
Cast, GroupingSet, InList, Like, Placeholder, - ScalarFunction, Sort, Unnest, + ScalarFunction, Unnest, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, BuiltInWindowFunction, Expr, - JoinConstraint, JoinType, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + JoinConstraint, JoinType, SortExpr, TryCast, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; use crate::protobuf::{ @@ -343,7 +343,7 @@ pub fn serialize_expr( None }; let partition_by = serialize_exprs(partition_by, codec)?; - let order_by = serialize_exprs(order_by, codec)?; + let order_by = serialize_sorts(order_by, codec)?; let window_frame: Option = Some(window_frame.try_into()?); @@ -380,7 +380,7 @@ pub fn serialize_expr( None => None, }, order_by: match order_by { - Some(e) => serialize_exprs(e, codec)?, + Some(e) => serialize_sorts(e, codec)?, None => vec![], }, fun_definition: (!buf.is_empty()).then_some(buf), @@ -537,20 +537,6 @@ pub fn serialize_expr( expr_type: Some(ExprType::TryCast(expr)), } } - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let expr = Box::new(protobuf::SortExprNode { - expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - asc: *asc, - nulls_first: *nulls_first, - }); - protobuf::LogicalExprNode { - expr_type: Some(ExprType::Sort(expr)), - } - } Expr::Negative(expr) => { let expr = Box::new(protobuf::NegativeNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), @@ -635,6 +621,30 @@ pub fn serialize_expr( Ok(expr_node) } +pub fn serialize_sorts<'a, I>( + sorts: I, + codec: &dyn LogicalExtensionCodec, +) -> Result, Error> +where + I: IntoIterator, +{ + sorts + .into_iter() + .map(|sort| { + let SortExpr { + expr, + asc, + nulls_first, + } = sort; + Ok(protobuf::SortExprNode { + expr: Some(serialize_expr(expr.as_ref(), codec)?), + asc: *asc, + nulls_first: *nulls_first, + }) + }) + .collect::, Error>>() +} + impl From for protobuf::TableReference 
{ fn from(t: TableReference) -> Self { use protobuf::table_reference::TableReferenceEnum; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 94ac913e1968..e174d1b50713 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -59,7 +59,7 @@ use datafusion_common::{ use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, - Sort, Unnest, WildcardOptions, + Unnest, WildcardOptions, }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ @@ -1937,14 +1937,6 @@ fn roundtrip_try_cast() { roundtrip_expr_test(test_expr, ctx); } -#[test] -fn roundtrip_sort_expr() { - let test_expr = Expr::Sort(Sort::new(Box::new(lit(1.0_f32)), true, true)); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); -} - #[test] fn roundtrip_negative() { let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 71e40c20b80a..9c768eb73c2e 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -282,22 +282,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let func_deps = schema.functional_dependencies(); // Find whether ties are possible in the given ordering let is_ordering_strict = order_by.iter().find_map(|orderby_expr| { - if let Expr::Sort(sort_expr) = orderby_expr { - if let Expr::Column(col) = sort_expr.expr.as_ref() { - let idx = schema.index_of_column(col).ok()?; - return if func_deps.iter().any(|dep| { - dep.source_indices == vec![idx] - && dep.mode == Dependency::Single - }) { - Some(true) - } else { - Some(false) - }; - } - Some(false) - } else { - panic!("order_by expression must be of type Sort"); + if let Expr::Column(col) = 
orderby_expr.expr.as_ref() { + let idx = schema.index_of_column(col).ok()?; + return if func_deps.iter().any(|dep| { + dep.source_indices == vec![idx] && dep.mode == Dependency::Single + }) { + Some(true) + } else { + Some(false) + }; } + Some(false) }); let window_frame = window diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 7fb32f714cfa..cdaa787cedd0 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -20,7 +20,7 @@ use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, }; use datafusion_expr::expr::Sort; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, SortExpr}; use sqlparser::ast::{Expr as SQLExpr, OrderByExpr, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -44,7 +44,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, literal_to_column: bool, additional_schema: Option<&DFSchema>, - ) -> Result> { + ) -> Result> { if exprs.is_empty() { return Ok(vec![]); } @@ -99,13 +99,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; let asc = asc.unwrap_or(true); - expr_vec.push(Expr::Sort(Sort::new( + expr_vec.push(Sort::new( Box::new(expr), asc, // when asc is true, by default nulls last to be consistent with postgres // postgres rule: https://www.postgresql.org/docs/current/queries-order.html nulls_first.unwrap_or(!asc), - ))) + )) } Ok(expr_vec) } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index ba2b41bb6ecf..71328cfd018c 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{not_impl_err, plan_err, Constraints, Result, ScalarValue}; +use datafusion_expr::expr::Sort; use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, Operator, @@ -119,7 +120,7 @@ impl<'a, 
S: ContextProvider> SqlToRel<'a, S> { pub(super) fn order_by( &self, plan: LogicalPlan, - order_by: Vec, + order_by: Vec, ) -> Result { if order_by.is_empty() { return Ok(plan); diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 384893bfa94c..8a26671fcb6c 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -31,7 +31,7 @@ use datafusion_common::UnnestOptions; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ - normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_cols, + normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_sorts, }; use datafusion_expr::utils::{ expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_window_exprs, @@ -107,7 +107,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { true, Some(base_plan.schema().as_ref()), )?; - let order_by_rex = normalize_cols(order_by_rex, &projected_plan)?; + let order_by_rex = normalize_sorts(order_by_rex, &projected_plan)?; // this alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index e75a96e78d48..3dfc379b039a 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -48,9 +48,10 @@ use datafusion_expr::{ CreateIndex as PlanCreateIndex, CreateMemoryTable, CreateView, DescribeTable, DmlStatement, DropCatalogSchema, DropFunction, DropTable, DropView, EmptyRelation, Explain, Expr, ExprSchemable, Filter, LogicalPlan, LogicalPlanBuilder, - OperateFunctionArg, PlanType, Prepare, SetVariable, Statement as PlanStatement, - ToStringifiedPlan, TransactionAccessMode, TransactionConclusion, TransactionEnd, - TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, + OperateFunctionArg, PlanType, 
Prepare, SetVariable, SortExpr, + Statement as PlanStatement, ToStringifiedPlan, TransactionAccessMode, + TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, + Volatility, WriteOp, }; use sqlparser::ast; use sqlparser::ast::{ @@ -952,7 +953,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_exprs: Vec, schema: &DFSchemaRef, planner_context: &mut PlannerContext, - ) -> Result>> { + ) -> Result>> { // Ask user to provide a schema if schema is empty. if !order_exprs.is_empty() && schema.fields().is_empty() { return plan_err!( @@ -966,8 +967,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let expr_vec = self.order_by_to_sort_expr(expr, schema, planner_context, true, None)?; // Verify that columns of all SortExprs exist in the schema: - for expr in expr_vec.iter() { - for column in expr.column_refs().iter() { + for sort in expr_vec.iter() { + for column in sort.expr.column_refs().iter() { if !schema.has_column(column) { // Return an error if any column is not in the schema: return plan_err!("Column {column} is not in schema"); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 0dbcba162bc0..9a3f139fdee8 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use core::fmt; - use datafusion_expr::ScalarUDF; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ @@ -24,7 +22,7 @@ use sqlparser::ast::{ ObjectName, TimezoneInfo, UnaryOperator, }; use std::sync::Arc; -use std::{fmt::Display, vec}; +use std::vec; use super::dialect::{DateFieldExtractStyle, IntervalStyle}; use super::Unparser; @@ -46,33 +44,6 @@ use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, }; -/// DataFusion's Exprs can represent either an `Expr` or an `OrderByExpr` -pub enum Unparsed { - // SQL Expression - Expr(ast::Expr), - // SQL ORDER BY expression (e.g. `col ASC NULLS FIRST`) - OrderByExpr(ast::OrderByExpr), -} - -impl Unparsed { - pub fn into_order_by_expr(self) -> Result { - if let Unparsed::OrderByExpr(order_by_expr) = self { - Ok(order_by_expr) - } else { - internal_err!("Expected Sort expression to be converted an OrderByExpr") - } - } -} - -impl Display for Unparsed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Unparsed::Expr(expr) => write!(f, "{}", expr), - Unparsed::OrderByExpr(order_by_expr) => write!(f, "{}", order_by_expr), - } - } -} - /// Convert a DataFusion [`Expr`] to [`ast::Expr`] /// /// This function is the opposite of [`SqlToRel::sql_to_expr`] and can be used @@ -106,13 +77,9 @@ pub fn expr_to_sql(expr: &Expr) -> Result { unparser.expr_to_sql(expr) } -/// Convert a DataFusion [`Expr`] to [`Unparsed`] -/// -/// This function is similar to expr_to_sql, but it supports converting more [`Expr`] types like -/// `Sort` expressions to `OrderByExpr` expressions. 
-pub fn expr_to_unparsed(expr: &Expr) -> Result { +pub fn sort_to_sql(sort: &Sort) -> Result { let unparser = Unparser::default(); - unparser.expr_to_unparsed(expr) + unparser.sort_to_sql(sort) } const LOWEST: &BinaryOperator = &BinaryOperator::Or; @@ -286,7 +253,7 @@ impl Unparser<'_> { }; let order_by: Vec = order_by .iter() - .map(|expr| expr_to_unparsed(expr)?.into_order_by_expr()) + .map(sort_to_sql) .collect::>>()?; let start_bound = self.convert_bound(&window_frame.start_bound)?; @@ -413,11 +380,6 @@ impl Unparser<'_> { negated: *negated, }) } - Expr::Sort(Sort { - expr: _, - asc: _, - nulls_first: _, - }) => plan_err!("Sort expression should be handled by expr_to_unparsed"), Expr::IsNull(expr) => { Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?))) } @@ -534,36 +496,26 @@ impl Unparser<'_> { } } - /// This function can convert more [`Expr`] types than `expr_to_sql`, - /// returning an [`Unparsed`] like `Sort` expressions to `OrderByExpr` - /// expressions. - pub fn expr_to_unparsed(&self, expr: &Expr) -> Result { - match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let sql_parser_expr = self.expr_to_sql(expr)?; + pub fn sort_to_sql(&self, sort: &Sort) -> Result { + let Sort { + expr, + asc, + nulls_first, + } = sort; + let sql_parser_expr = self.expr_to_sql(expr)?; - let nulls_first = if self.dialect.supports_nulls_first_in_sort() { - Some(*nulls_first) - } else { - None - }; + let nulls_first = if self.dialect.supports_nulls_first_in_sort() { + Some(*nulls_first) + } else { + None + }; - Ok(Unparsed::OrderByExpr(ast::OrderByExpr { - expr: sql_parser_expr, - asc: Some(*asc), - nulls_first, - with_fill: None, - })) - } - _ => { - let sql_parser_expr = self.expr_to_sql(expr)?; - Ok(Unparsed::Expr(sql_parser_expr)) - } - } + Ok(ast::OrderByExpr { + expr: sql_parser_expr, + asc: Some(*asc), + nulls_first, + with_fill: None, + }) } fn scalar_function_to_sql_overrides( @@ -1809,11 +1761,7 @@ mod tests { fun: 
WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], - order_by: vec![Expr::Sort(Sort::new( - Box::new(col("a")), - false, - true, - ))], + order_by: vec![Sort::new(Box::new(col("a")), false, true)], window_frame: WindowFrame::new_bounds( datafusion_expr::WindowFrameUnits::Range, datafusion_expr::WindowFrameBound::Preceding( @@ -1941,24 +1889,6 @@ mod tests { Ok(()) } - #[test] - fn expr_to_unparsed_ok() -> Result<()> { - let tests: Vec<(Expr, &str)> = vec![ - ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), - (col("a").sort(true, true), r#"a ASC NULLS FIRST"#), - ]; - - for (expr, expected) in tests { - let ast = expr_to_unparsed(&expr)?; - - let actual = format!("{}", ast); - - assert_eq!(actual, expected); - } - - Ok(()) - } - #[test] fn custom_dialect_with_identifier_quote_style() -> Result<()> { let dialect = CustomDialectBuilder::new() @@ -2047,7 +1977,7 @@ mod tests { #[test] fn customer_dialect_support_nulls_first_in_ort() -> Result<()> { - let tests: Vec<(Expr, &str, bool)> = vec![ + let tests: Vec<(Sort, &str, bool)> = vec![ (col("a").sort(true, true), r#"a ASC NULLS FIRST"#, true), (col("a").sort(true, true), r#"a ASC"#, false), ]; @@ -2057,7 +1987,7 @@ mod tests { .with_supports_nulls_first_in_sort(supports_nulls_first_in_sort) .build(); let unparser = Unparser::new(&dialect); - let ast = unparser.expr_to_unparsed(&expr)?; + let ast = unparser.sort_to_sql(&expr)?; let actual = format!("{}", ast); diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index b2fd32566aa8..83ae64ba238b 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -29,8 +29,6 @@ pub use plan::plan_to_sql; use self::dialect::{DefaultDialect, Dialect}; pub mod dialect; -pub use expr::Unparsed; - /// Convert a DataFusion [`Expr`] to [`sqlparser::ast::Expr`] /// /// See [`expr_to_sql`] for background. 
`Unparser` allows greater control of diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 106705c322fc..509c5dd52cd4 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{ - internal_err, not_impl_err, plan_err, Column, DataFusionError, Result, -}; +use datafusion_common::{internal_err, not_impl_err, Column, DataFusionError, Result}; use datafusion_expr::{ expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, Projection, + SortExpr, }; use sqlparser::ast::{self, Ident, SetExpr}; @@ -318,7 +317,7 @@ impl Unparser<'_> { return self.derive(plan, relation); } if let Some(query_ref) = query { - query_ref.order_by(self.sort_to_sql(sort.expr.clone())?); + query_ref.order_by(self.sorts_to_sql(sort.expr.clone())?); } else { return internal_err!( "Sort operator only valid in a statement context." @@ -361,7 +360,7 @@ impl Unparser<'_> { .collect::>>()?; if let Some(sort_expr) = &on.sort_expr { if let Some(query_ref) = query { - query_ref.order_by(self.sort_to_sql(sort_expr.clone())?); + query_ref.order_by(self.sorts_to_sql(sort_expr.clone())?); } else { return internal_err!( "Sort operator only valid in a statement context." @@ -525,14 +524,10 @@ impl Unparser<'_> { } } - fn sort_to_sql(&self, sort_exprs: Vec) -> Result> { + fn sorts_to_sql(&self, sort_exprs: Vec) -> Result> { sort_exprs .iter() - .map(|expr: &Expr| { - self.expr_to_unparsed(expr)? 
- .into_order_by_expr() - .or(plan_err!("Expecting Sort expr")) - }) + .map(|sort_expr| self.sort_to_sql(sort_expr)) .collect::>>() } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 9e1adcf4df31..522a08af8546 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -21,10 +21,11 @@ use std::{ }; use datafusion_common::{ - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator}, + tree_node::{Transformed, TransformedResult, TreeNode}, Result, }; -use datafusion_expr::{Expr, LogicalPlan, Projection, Sort}; +use datafusion_expr::tree_node::transform_sort_vec; +use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; /// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. @@ -83,20 +84,18 @@ pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result } /// Rewrite sort expressions that have a UNION plan as their input to remove the table reference. 
-fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { - let sort_exprs: Vec = exprs - .into_iter() - .map_until_stop_and_collect(|expr| { - expr.transform_up(|expr| { - if let Expr::Column(mut col) = expr { - col.relation = None; - Ok(Transformed::yes(Expr::Column(col))) - } else { - Ok(Transformed::no(expr)) - } - }) +fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { + let sort_exprs = transform_sort_vec(exprs, &mut |expr| { + expr.transform_up(|expr| { + if let Expr::Column(mut col) = expr { + col.relation = None; + Ok(Transformed::yes(Expr::Column(col))) + } else { + Ok(Transformed::no(expr)) + } }) - .data()?; + }) + .data()?; Ok(sort_exprs) } @@ -158,12 +157,8 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( .collect::>(); let mut collects = p.expr.clone(); - for expr in &sort.expr { - if let Expr::Sort(s) = expr { - collects.push(s.expr.as_ref().clone()); - } else { - panic!("sort expression must be of type Sort"); - } + for sort in &sort.expr { + collects.push(sort.expr.as_ref().clone()); } // Compare outer collects Expr::to_string with inner collected transformed values diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b1b510f1792d..05903bb56cfe 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -31,7 +31,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr, - ExprSchemable, LogicalPlan, Operator, Projection, Values, + ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values, }; use substrait::proto::expression::subquery::set_predicate::PredicateOp; use url::Url; @@ -900,8 +900,8 @@ pub async fn from_substrait_sorts( substrait_sorts: &Vec, input_schema: &DFSchema, extensions: &Extensions, -) -> Result> { - let mut sorts: Vec = vec![]; +) -> Result> { + let mut 
sorts: Vec = vec![]; for s in substrait_sorts { let expr = from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) @@ -935,11 +935,11 @@ pub async fn from_substrait_sorts( None => not_impl_err!("Sort without sort kind is invalid"), }; let (asc, nulls_first) = asc_nullfirst.unwrap(); - sorts.push(Expr::Sort(Sort { + sorts.push(Sort { expr: Box::new(expr), asc, nulls_first, - })); + }); } Ok(sorts) } @@ -986,7 +986,7 @@ pub async fn from_substrait_agg_func( input_schema: &DFSchema, extensions: &Extensions, filter: Option>, - order_by: Option>, + order_by: Option>, distinct: bool, ) -> Result> { let args = diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 72b6760be29c..592390a285ba 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -808,31 +808,26 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( ctx: &SessionContext, - expr: &Expr, + sort: &Sort, schema: &DFSchemaRef, extensions: &mut Extensions, ) -> Result { - match expr { - Expr::Sort(sort) => { - let sort_kind = match (sort.asc, sort.nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(to_substrait_rex( - ctx, - sort.expr.deref(), - schema, - 0, - extensions, - )?), - sort_kind: Some(SortKind::Direction(sort_kind.into())), - }) - } - _ => exec_err!("expects to receive sort expression"), - } + let sort_kind = match (sort.asc, sort.nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: 
Some(to_substrait_rex( + ctx, + sort.expr.deref(), + schema, + 0, + extensions, + )?), + sort_kind: Some(SortKind::Direction(sort_kind.into())), + }) } /// Return Substrait scalar function with two arguments @@ -2107,30 +2102,26 @@ fn try_to_substrait_field_reference( fn substrait_sort_field( ctx: &SessionContext, - expr: &Expr, + sort: &Sort, schema: &DFSchemaRef, extensions: &mut Extensions, ) -> Result { - match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; - let d = match (asc, nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(e), - sort_kind: Some(SortKind::Direction(d as i32)), - }) - } - _ => not_impl_err!("Expecting sort expression but got {expr:?}"), - } + let Sort { + expr, + asc, + nulls_first, + } = sort; + let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; + let d = match (asc, nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(e), + sort_kind: Some(SortKind::Direction(d as i32)), + }) } fn substrait_field_ref(index: usize) -> Result { diff --git a/docs/source/library-user-guide/using-the-dataframe-api.md b/docs/source/library-user-guide/using-the-dataframe-api.md index 3bd47ef50e51..7f3e28c255c6 100644 --- a/docs/source/library-user-guide/using-the-dataframe-api.md +++ b/docs/source/library-user-guide/using-the-dataframe-api.md @@ -263,14 +263,14 @@ async fn main() -> Result<()>{ let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; // Create a new DataFrame sorted by `id`, `bank_account` let new_df = df.select(vec![col("a"), col("b")])? 
- .sort(vec![col("a")])?; + .sort_by(vec![col("a")])?; // Build the same plan using the LogicalPlanBuilder // Similar to `SELECT a, b FROM example.csv ORDER BY a` let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; let (_state, plan) = df.into_parts(); // get the DataFrame's LogicalPlan let plan = LogicalPlanBuilder::from(plan) .project(vec![col("a"), col("b")])? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .build()?; // prove they are the same assert_eq!(new_df.logical_plan(), &plan);