From 8eb4cdb783be435509747962166b43dbf78a29a6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 8 Jun 2024 11:27:19 -0400 Subject: [PATCH] Stop copying LogicalPlan and Exprs in `CommonSubexprEliminate` --- datafusion/expr/src/logical_plan/plan.rs | 66 ++- .../optimizer/src/common_subexpr_eliminate.rs | 531 ++++++++++++------ 2 files changed, 380 insertions(+), 217 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 7abe6b70b64e..52ac5daa135d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -870,37 +870,7 @@ impl LogicalPlan { LogicalPlan::Filter { .. } => { assert_eq!(1, expr.len()); let predicate = expr.pop().unwrap(); - - // filter predicates should not contain aliased expressions so we remove any aliases - // before this logic was added we would have aliases within filters such as for - // benchmark q6: - // - // lineitem.l_shipdate >= Date32(\"8766\") - // AND lineitem.l_shipdate < Date32(\"9131\") - // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= - // Decimal128(Some(49999999999999),30,15) - // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= - // Decimal128(Some(69999999999999),30,15) - // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - - let predicate = predicate - .transform_down(|expr| { - match expr { - Expr::Exists { .. 
} - | Expr::ScalarSubquery(_) - | Expr::InSubquery(_) => { - // subqueries could contain aliases so we don't recurse into those - Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) - } - Expr::Alias(_) => Ok(Transformed::new( - expr.unalias(), - true, - TreeNodeRecursion::Jump, - )), - _ => Ok(Transformed::no(expr)), - } - }) - .data()?; + let predicate = Filter::remove_aliases(predicate)?.data; Filter::try_new(predicate, Arc::new(inputs.swap_remove(0))) .map(LogicalPlan::Filter) @@ -2230,6 +2200,40 @@ impl Filter { } false } + + /// Remove aliases from a predicate for use in a `Filter` + /// + /// filter predicates should not contain aliased expressions so we remove + /// any aliases. + /// + /// before this logic was added we would have aliases within filters such as + /// for benchmark q6: + /// + /// ```sql + /// lineitem.l_shipdate >= Date32(\"8766\") + /// AND lineitem.l_shipdate < Date32(\"9131\") + /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= + /// Decimal128(Some(49999999999999),30,15) + /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= + /// Decimal128(Some(69999999999999),30,15) + /// AND lineitem.l_quantity < Decimal128(Some(2400),15,2) + /// ``` + pub fn remove_aliases(predicate: Expr) -> Result> { + predicate.transform_down(|expr| { + match expr { + Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { + // subqueries could contain aliases so we don't recurse into those + Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) + } + Expr::Alias(_) => Ok(Transformed::new( + expr.unalias(), + true, + TreeNodeRecursion::Jump, + )), + _ => Ok(Transformed::no(expr)), + } + }) + } } /// Window its input based on a set of window spec and window function (e.g. 
SUM or RANK) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 6820ba04f0e9..9ad44890ef24 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -20,18 +20,24 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; +use crate::optimizer::ApplyOrder; +use crate::utils::NamePreserver; use arrow::datatypes::{DataType, Field}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ - internal_err, qualified_name, Column, DFSchema, DFSchemaRef, DataFusionError, Result, + internal_datafusion_err, internal_err, qualified_name, Column, DFSchema, DFSchemaRef, + Result, }; use datafusion_expr::expr::Alias; -use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; +use datafusion_expr::logical_plan::{ + Aggregate, Filter, LogicalPlan, Projection, Sort, Window, +}; use datafusion_expr::{col, Expr, ExprSchemable}; use indexmap::IndexMap; @@ -127,21 +133,21 @@ impl CommonSubexprEliminate { /// Returns the rewritten expressions fn rewrite_exprs_list( &self, - exprs_list: &[&[Expr]], + exprs_list: Vec>, arrays_list: &[&[Vec<(usize, String)>]], expr_stats: &ExprStats, common_exprs: &mut CommonExprs, ) -> Result>> { exprs_list - .iter() + .into_iter() .zip(arrays_list.iter()) .map(|(exprs, arrays)| { exprs - .iter() - .cloned() + .into_iter() .zip(arrays.iter()) .map(|(expr, id_array)| { replace_common_expr(expr, id_array, expr_stats, common_exprs) + .data() }) .collect::>>() }) @@ -158,9 +164,9 @@ impl CommonSubexprEliminate { /// common sub-expressions that were used fn rewrite_expr( &self, - exprs_list: &[&[Expr]], + exprs_list: Vec>, 
arrays_list: &[&[Vec<(usize, String)>]], - input: &LogicalPlan, + input: LogicalPlan, expr_stats: &ExprStats, config: &dyn OptimizerConfig, ) -> Result<(Vec>, LogicalPlan)> { @@ -173,9 +179,8 @@ impl CommonSubexprEliminate { &mut common_exprs, )?; - let mut new_input = self - .try_optimize(input, config)? - .unwrap_or_else(|| input.clone()); + let mut new_input = self.rewrite(input, config)?.data; + if !common_exprs.is_empty() { new_input = build_common_expr_project_plan(new_input, common_exprs, expr_stats)?; @@ -184,90 +189,138 @@ impl CommonSubexprEliminate { Ok((rewrite_exprs, new_input)) } - fn try_optimize_window( + fn try_optimize_proj( &self, - window: &Window, + projection: Projection, config: &dyn OptimizerConfig, - ) -> Result { - let mut window_exprs = vec![]; - let mut arrays_per_window = vec![]; - let mut expr_stats = ExprStats::new(); - - // Get all window expressions inside the consecutive window operators. - // Consecutive window expressions may refer to same complex expression. - // If same complex expression is referred more than once by subsequent `WindowAggr`s, - // we can cache complex expression by evaluating it with a projection before the - // first WindowAggr. - // This enables us to cache complex expression "c3+c4" for following plan: - // WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - // --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - // where, it is referred once by each `WindowAggr` (total of 2) in the plan. - let mut plan = LogicalPlan::Window(window.clone()); - while let LogicalPlan::Window(window) = plan { - let Window { - input, window_expr, .. - } = window; - plan = input.as_ref().clone(); + ) -> Result> { + let Projection { + expr, + input, + schema, + .. + } = projection; + let input = unwrap_arc(input); + self.try_unary_plan(expr, input, config)? 
+ .map_data(|(new_expr, new_input)| { + Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema) + .map(LogicalPlan::Projection) + }) + } + fn try_optimize_sort( + &self, + sort: Sort, + config: &dyn OptimizerConfig, + ) -> Result> { + let Sort { expr, input, fetch } = sort; + let input = unwrap_arc(input); + let new_sort = self.try_unary_plan(expr, input, config)?.update_data( + |(new_expr, new_input)| { + LogicalPlan::Sort(Sort { + expr: new_expr, + input: Arc::new(new_input), + fetch, + }) + }, + ); + Ok(new_sort) + } - let input_schema = Arc::clone(input.schema()); - let arrays = to_arrays( - &window_expr, - input_schema, - &mut expr_stats, - ExprMask::Normal, - )?; + fn try_optimize_filter( + &self, + filter: Filter, + config: &dyn OptimizerConfig, + ) -> Result> { + let Filter { + predicate, input, .. + } = filter; + let input = unwrap_arc(input); + let expr = vec![predicate]; + self.try_unary_plan(expr, input, config)? + .transform_data(|(mut new_expr, new_input)| { + assert_eq!(new_expr.len(), 1); // passed in vec![predicate] + let new_predicate = new_expr.pop().unwrap(); + Ok(Filter::remove_aliases(new_predicate)? + .update_data(|new_predicate| (new_predicate, new_input))) + })? 
+ .map_data(|(new_predicate, new_input)| { + Filter::try_new(new_predicate, Arc::new(new_input)) + .map(LogicalPlan::Filter) + }) + } - window_exprs.push(window_expr); - arrays_per_window.push(arrays); - } + fn try_optimize_window( + &self, + window: Window, + config: &dyn OptimizerConfig, + ) -> Result> { + // collect all window expressions from any number of LogicalPlanWindow + let ConsecutiveWindowExprs { + window_exprs, + arrays_per_window, + expr_stats, + plan, + } = ConsecutiveWindowExprs::try_new(window)?; - let mut window_exprs = window_exprs - .iter() - .map(|expr| expr.as_slice()) - .collect::>(); let arrays_per_window = arrays_per_window .iter() .map(|arrays| arrays.as_slice()) .collect::>(); + // save the original names + let name_preserver = NamePreserver::new(&plan); + let mut saved_names = window_exprs + .iter() + .map(|exprs| { + exprs + .iter() + .map(|expr| name_preserver.save(expr)) + .collect::>>() + }) + .collect::>>()?; + assert_eq!(window_exprs.len(), arrays_per_window.len()); + let num_window_exprs = window_exprs.len(); let (mut new_expr, new_input) = self.rewrite_expr( - &window_exprs, + window_exprs, &arrays_per_window, - &plan, + plan, &expr_stats, config, )?; - assert_eq!(window_exprs.len(), new_expr.len()); - // Construct consecutive window operator, with their corresponding new window expressions. - plan = new_input; - while let Some(new_window_expr) = new_expr.pop() { - // Since `new_expr` and `window_exprs` length are same. We can safely `.unwrap` here. - let orig_window_expr = window_exprs.pop().unwrap(); - assert_eq!(new_window_expr.len(), orig_window_expr.len()); + let mut plan = new_input; - // Rename new re-written window expressions with original name (by giving alias) - // Otherwise we may receive schema error, in subsequent operators. + // Construct consecutive window operator, with their corresponding new + // window expressions. 
+ // + // Note this iterates over, `new_expr` and `saved_names` which are the + // same length, in reverse order + assert_eq!(num_window_exprs, new_expr.len()); + assert_eq!(num_window_exprs, saved_names.len()); + while let (Some(new_window_expr), Some(saved_names)) = + (new_expr.pop(), saved_names.pop()) + { + assert_eq!(new_window_expr.len(), saved_names.len()); + + // Rename re-written window expressions with original name, to + // preserve the output schema let new_window_expr = new_window_expr .into_iter() - .zip(orig_window_expr.iter()) - .map(|(new_window_expr, window_expr)| { - let original_name = window_expr.name_for_alias()?; - new_window_expr.alias_if_changed(original_name) - }) + .zip(saved_names.into_iter()) + .map(|(new_window_expr, saved_name)| saved_name.restore(new_window_expr)) .collect::>>()?; plan = LogicalPlan::Window(Window::try_new(new_window_expr, Arc::new(plan))?); } - Ok(plan) + Ok(Transformed::yes(plan)) } fn try_optimize_aggregate( &self, - aggregate: &Aggregate, + aggregate: Aggregate, config: &dyn OptimizerConfig, - ) -> Result { + ) -> Result> { let Aggregate { group_expr, aggr_expr, @@ -279,18 +332,25 @@ impl CommonSubexprEliminate { // rewrite inputs let input_schema = Arc::clone(input.schema()); let group_arrays = to_arrays( - group_expr, + &group_expr, Arc::clone(&input_schema), &mut expr_stats, ExprMask::Normal, )?; let aggr_arrays = - to_arrays(aggr_expr, input_schema, &mut expr_stats, ExprMask::Normal)?; + to_arrays(&aggr_expr, input_schema, &mut expr_stats, ExprMask::Normal)?; + let name_perserver = NamePreserver::new_for_projection(); + let saved_names = aggr_expr + .iter() + .map(|expr| name_perserver.save(expr)) + .collect::>>()?; + + // rewrite both group exprs and aggr_expr let (mut new_expr, new_input) = self.rewrite_expr( - &[group_expr, aggr_expr], + vec![group_expr, aggr_expr], &[&group_arrays, &aggr_arrays], - input, + unwrap_arc(input), &expr_stats, config, )?; @@ -303,13 +363,13 @@ impl CommonSubexprEliminate { let 
new_input_schema = Arc::clone(new_input.schema()); let aggr_arrays = to_arrays( &new_aggr_expr, - new_input_schema.clone(), + Arc::clone(&new_input_schema), &mut expr_stats, ExprMask::NormalAndAggregates, )?; let mut common_exprs = IndexMap::new(); let mut rewritten = self.rewrite_exprs_list( - &[&new_aggr_expr], + vec![new_aggr_expr.clone()], &[&aggr_arrays], &expr_stats, &mut common_exprs, @@ -319,98 +379,201 @@ impl CommonSubexprEliminate { if common_exprs.is_empty() { // Alias aggregation expressions if they have changed let new_aggr_expr = new_aggr_expr - .iter() - .zip(aggr_expr.iter()) - .map(|(new_expr, old_expr)| { - new_expr.clone().alias_if_changed(old_expr.display_name()?) - }) + .into_iter() + .zip(saved_names.into_iter()) + .map(|(new_expr, saved_name)| saved_name.restore(new_expr)) .collect::>>()?; // Since group_epxr changes, schema changes also. Use try_new method. - Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) - .map(LogicalPlan::Aggregate) - } else { - let mut agg_exprs = common_exprs - .into_iter() - .map(|(expr_id, expr)| { - // todo: check `nullable` - expr.alias(expr_id) - }) - .collect::>(); + return Aggregate::try_new( + Arc::new(new_input), + new_group_expr, + new_aggr_expr, + ) + .map(LogicalPlan::Aggregate) + .map(Transformed::yes); + } + let mut agg_exprs = common_exprs + .into_iter() + .map(|(expr_id, expr)| { + // todo: check `nullable` + expr.alias(expr_id) + }) + .collect::>(); - let mut proj_exprs = vec![]; - for expr in &new_group_expr { - extract_expressions(expr, &new_input_schema, &mut proj_exprs)? - } - for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) { - if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. 
}) = expr_rewritten { - agg_exprs.push(expr.alias(&name)); - proj_exprs.push(Expr::Column(Column::from_name(name))); - } else { - let id = expr_identifier(&expr_rewritten, "".to_string()); - let (qualifier, field) = - expr_rewritten.to_field(&new_input_schema)?; - let out_name = qualified_name(qualifier.as_ref(), field.name()); - - agg_exprs.push(expr_rewritten.alias(&id)); - proj_exprs - .push(Expr::Column(Column::from_name(id)).alias(out_name)); - } + let mut proj_exprs = vec![]; + for expr in &new_group_expr { + extract_expressions(expr, &new_input_schema, &mut proj_exprs)? + } + for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) { + if expr_rewritten == expr_orig { + if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { + agg_exprs.push(expr.alias(&name)); + proj_exprs.push(Expr::Column(Column::from_name(name))); } else { - proj_exprs.push(expr_rewritten); + let id = expr_identifier(&expr_rewritten, "".to_string()); + let (qualifier, field) = + expr_rewritten.to_field(&new_input_schema)?; + let out_name = qualified_name(qualifier.as_ref(), field.name()); + + agg_exprs.push(expr_rewritten.alias(&id)); + proj_exprs.push(Expr::Column(Column::from_name(id)).alias(out_name)); } + } else { + proj_exprs.push(expr_rewritten); } + } - let agg = LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(new_input), - new_group_expr, - agg_exprs, - )?); + let agg = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(new_input), + new_group_expr, + agg_exprs, + )?); - Ok(LogicalPlan::Projection(Projection::try_new( - proj_exprs, - Arc::new(agg), - )?)) - } + Projection::try_new(proj_exprs, Arc::new(agg)) + .map(LogicalPlan::Projection) + .map(Transformed::yes) } + /// Rewrites the expr list and input to remove common subexpressions + /// + /// # Parameters + /// + /// * `exprs`: List of expressions in the node + /// * `input`: input plan (that produces the columns referred to in `exprs`) + /// + /// # Return value + /// + /// Returns 
`(rewritten_exprs, new_input)`. `new_input` is either: + /// + /// 1. The original `input` if no common subexpressions were extracted + /// 2. A newly added projection on top of the original input + /// that computes the common subexpressions fn try_unary_plan( &self, - plan: &LogicalPlan, + expr: Vec, + input: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result { - let expr = plan.expressions(); - let inputs = plan.inputs(); - let input = inputs[0]; - let input_schema = Arc::clone(input.schema()); + ) -> Result, LogicalPlan)>> { let mut expr_stats = ExprStats::new(); - - // Visit expr list and build expr identifier to occuring count map (`expr_stats`). - let arrays = to_arrays(&expr, input_schema, &mut expr_stats, ExprMask::Normal)?; + let arrays = to_arrays( + &expr, + Arc::clone(input.schema()), + &mut expr_stats, + ExprMask::Normal, + )?; let (mut new_expr, new_input) = - self.rewrite_expr(&[&expr], &[&arrays], input, &expr_stats, config)?; + self.rewrite_expr(vec![expr], &[&arrays], input, &expr_stats, config)?; + assert_eq!(new_expr.len(), 1); + let result = (new_expr.pop().unwrap(), new_input); + Ok(Transformed::yes(result)) + } +} + +/// Get all window expressions inside the consecutive window operators. +/// +/// Returns the window expressions, and the input to the deepest child +/// LogicalPlan. +/// +/// For example, if the input window looks like +/// +/// ```text +/// LogicalPlan::Window(exprs=[a, b, c]) +/// LogicalPlan::Window(exprs=[d]) +/// InputPlan +/// ``` +/// +/// Returns: +/// * `window_exprs`: `[[a, b, c], [d]]` +/// * InputPlan +/// +/// Consecutive window expressions may refer to same complex expression. +/// +/// If same complex expression is referred more than once by subsequent +/// `WindowAggr`s, we can cache complex expression by evaluating it with a +/// projection before the first WindowAggr. 
+/// +/// This enables us to cache complex expression "c3+c4" for following plan: +/// +/// ```text +/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +/// ``` +/// +/// where, it is referred once by each `WindowAggr` (total of 2) in the plan. +struct ConsecutiveWindowExprs { + window_exprs: Vec>, + /// result of calling `to_arrays` on each set of window exprs + arrays_per_window: Vec>>, + expr_stats: ExprStats, + /// input plan to the window + plan: LogicalPlan, +} + +impl ConsecutiveWindowExprs { + fn try_new(window: Window) -> Result { + let mut window_exprs = vec![]; + let mut arrays_per_window = vec![]; + let mut expr_stats = ExprStats::new(); + + let mut plan = LogicalPlan::Window(window); + while let LogicalPlan::Window(Window { + input, window_expr, .. + }) = plan + { + let input_schema = Arc::clone(input.schema()); + plan = unwrap_arc(input); - plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input]) + let arrays = to_arrays( + &window_expr, + input_schema, + &mut expr_stats, + ExprMask::Normal, + )?; + + window_exprs.push(window_expr); + arrays_per_window.push(arrays); + } + + Ok(Self { + window_exprs, + arrays_per_window, + expr_stats, + plan, + }) } } impl OptimizerRule for CommonSubexprEliminate { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called CommonSubexprEliminate::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let original_schema = Arc::clone(plan.schema()); + let optimized_plan = match plan { - LogicalPlan::Projection(_) - | LogicalPlan::Sort(_) - | 
LogicalPlan::Filter(_) => Some(self.try_unary_plan(plan, config)?), - LogicalPlan::Window(window) => { - Some(self.try_optimize_window(window, config)?) - } - LogicalPlan::Aggregate(aggregate) => { - Some(self.try_optimize_aggregate(aggregate, config)?) - } + LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?, + LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?, + LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?, + LogicalPlan::Window(window) => self.try_optimize_window(window, config)?, + LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_) | LogicalPlan::Repartition(_) @@ -433,21 +596,19 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Prepare(_) => { - // apply the optimization to all inputs of the plan - utils::optimize_children(self, plan, config)? + // ApplyOrder::TopDown handles recursion + Transformed::no(plan) } }; - let original_schema = plan.schema(); - match optimized_plan { - Some(optimized_plan) if optimized_plan.schema() != original_schema => { - // add an additional projection if the output schema changed. 
- Ok(Some(build_recover_project_plan( - original_schema, - optimized_plan, - )?)) - } - plan => Ok(plan), + // If we rewrote the plan, ensure the schema stays the same + if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema + { + optimized_plan.map_data(|optimized_plan| { + build_recover_project_plan(&original_schema, optimized_plan) + }) + } else { + Ok(optimized_plan) } } @@ -472,16 +633,23 @@ impl CommonSubexprEliminate { fn pop_expr(new_expr: &mut Vec>) -> Result> { new_expr .pop() - .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string())) + .ok_or_else(|| internal_datafusion_err!("Failed to pop expression")) } +/// Returns the identifier list for each element in `exprs` +/// +/// Returns an array with 1 element for each input expr in `exprs` +/// +/// Each element is itself the result of [`expr_to_identifier`] for that expr +/// (e.g. the identifiers for each node in the tree) fn to_arrays( - expr: &[Expr], + exprs: &[Expr], input_schema: DFSchemaRef, expr_stats: &mut ExprStats, expr_mask: ExprMask, ) -> Result>> { - expr.iter() + exprs + .iter() .map(|e| { let mut id_array = vec![]; expr_to_identifier( @@ -494,7 +662,7 @@ fn to_arrays( Ok(id_array) }) - .collect::>>() + .collect() } /// Build the "intermediate" projection plan that evaluates the extracted common @@ -532,10 +700,7 @@ fn build_common_expr_project_plan( } } - Ok(LogicalPlan::Projection(Projection::try_new( - project_exprs, - Arc::new(input), - )?)) + Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection) } /// Build the projection plan to eliminate unnecessary columns produced by @@ -548,10 +713,7 @@ fn build_recover_project_plan( input: LogicalPlan, ) -> Result { let col_exprs = schema.iter().map(Expr::from).collect(); - Ok(LogicalPlan::Projection(Projection::try_new( - col_exprs, - Arc::new(input), - )?)) + Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection) } fn extract_expressions( @@ 
-823,14 +985,13 @@ fn replace_common_expr( id_array: &IdArray, expr_stats: &ExprStats, common_exprs: &mut CommonExprs, -) -> Result { +) -> Result> { expr.rewrite(&mut CommonSubexprRewriter { expr_stats, id_array, common_exprs, down_index: 0, }) - .data() } #[cfg(test)] @@ -853,12 +1014,11 @@ mod test { use super::*; - fn assert_optimized_plan_eq(expected: &str, plan: &LogicalPlan) { + fn assert_optimized_plan_eq(expected: &str, plan: LogicalPlan) { let optimizer = CommonSubexprEliminate {}; - let optimized_plan = optimizer - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); + let optimized_plan = optimizer.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(optimized_plan.transformed, "failed to optimize plan"); + let optimized_plan = optimized_plan.data; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(expected, formatted_plan); } @@ -957,7 +1117,7 @@ mod test { \n Projection: test.a * (Int32(1) - test.b) AS {test.a * (Int32(1) - test.b)|{Int32(1) - test.b|{test.b}|{Int32(1)}}|{test.a}}, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1010,7 +1170,7 @@ mod test { \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS {AVG(test.a)|{test.a}}, my_agg(test.a) AS {my_agg(test.a)|{test.a}}, AVG(test.b) AS col3, AVG(test.c) AS {AVG(test.c)}, my_agg(test.b) AS col6, my_agg(test.c) AS {my_agg(test.c)}]]\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: trafo after aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -1029,7 +1189,7 @@ mod test { \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS {AVG(test.a)|{test.a}}, my_agg(test.a) AS {my_agg(test.a)|{test.a}}]]\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: transformation before aggregate let plan = 
LogicalPlanBuilder::from(table_scan.clone()) @@ -1044,7 +1204,7 @@ mod test { let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: common between agg and group let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -1061,7 +1221,7 @@ mod test { \n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: all mixed let plan = LogicalPlanBuilder::from(table_scan) @@ -1083,7 +1243,7 @@ mod test { \n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1110,7 +1270,7 @@ mod test { \n Projection: UInt32(1) + table.test.col.a AS {UInt32(1) + table.test.col.a|{table.test.col.a}|{UInt32(1)}}, table.test.col.a\ \n TableScan: table.test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1130,7 +1290,7 @@ mod test { \n Projection: Int32(1) + test.a AS {Int32(1) + test.a|{test.a}|{Int32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1146,7 +1306,7 @@ mod test { let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1164,7 +1324,7 @@ mod test { \n Projection: Int32(1) + test.a, test.a\ \n 
TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1264,10 +1424,9 @@ mod test { .build() .unwrap(); let rule = CommonSubexprEliminate {}; - let optimized_plan = rule - .try_optimize(&plan, &OptimizerContext::new()) - .unwrap() - .unwrap(); + let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(optimized_plan.transformed); + let optimized_plan = optimized_plan.data; let schema = optimized_plan.schema(); let fields_with_datatypes: Vec<_> = schema @@ -1306,7 +1465,7 @@ mod test { \n Projection: Int32(1) + test.a AS {Int32(1) + test.a|{test.a}|{Int32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) }