diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index eea5fc1127ce..44c55e1f880b 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -175,7 +175,7 @@ impl DataFrame { /// Consume the DataFrame and produce a physical plan pub async fn create_physical_plan(self) -> Result> { - self.session_state.create_physical_plan(&self.plan).await + self.session_state.create_physical_plan(self.plan).await } /// Filter the DataFrame by column. Returns a new DataFrame only containing the @@ -989,7 +989,7 @@ impl DataFrame { /// [`Self::into_optimized_plan`] for more details. pub fn into_optimized_plan(self) -> Result { // Optimize the plan first for better UX - self.session_state.optimize(&self.plan) + self.session_state.optimize(self.plan) } /// Converts this [`DataFrame`] into a [`TableProvider`] that can be registered @@ -1466,7 +1466,7 @@ impl TableProvider for DataFrameTableProvider { expr = expr.limit(0, Some(l))? } let plan = expr.build()?; - state.create_physical_plan(&plan).await + state.create_physical_plan(plan).await } } diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 85fb8939886c..1a985ad664af 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -141,7 +141,7 @@ impl TableProvider for ViewTable { plan = plan.limit(0, Some(limit))?; } - state.create_physical_plan(&plan.build()?).await + state.create_physical_plan(plan.build()?).await } } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 116e45c8c130..67595ad7192c 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -82,6 +82,7 @@ use datafusion_sql::{ use async_trait::async_trait; use chrono::{DateTime, Utc}; +use datafusion_common::tree_node::Transformed; use parking_lot::RwLock; use sqlparser::dialect::dialect_from_str; 
use url::Url; @@ -530,7 +531,7 @@ impl SessionContext { } = cmd; let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone()); - let input = self.state().optimize(&input)?; + let input = self.state().optimize(input)?; let table = self.table(&name).await; match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), @@ -1839,13 +1840,22 @@ impl SessionState { } /// Optimizes the logical plan by applying optimizer rules. - pub fn optimize(&self, plan: &LogicalPlan) -> Result { - if let LogicalPlan::Explain(e) = plan { - let mut stringified_plans = e.stringified_plans.clone(); + pub fn optimize(&self, plan: LogicalPlan) -> Result { + if let LogicalPlan::Explain(Explain { + verbose, + plan, + mut stringified_plans, + schema, + logical_optimization_succeeded, + }) = plan + { + // TODO this could be a dummy plan + let original_plan = plan.clone(); // keep original plan in case there is an error // analyze & capture output of each rule + // TODO avoid this copy let analyzer_result = self.analyzer.execute_and_check( - e.plan.as_ref(), + &plan, self.options(), |analyzed_plan, analyzer| { let analyzer_name = analyzer.name().to_string(); @@ -1861,10 +1871,10 @@ impl SessionState { .push(StringifiedPlan::new(plan_type, err.to_string())); return Ok(LogicalPlan::Explain(Explain { - verbose: e.verbose, - plan: e.plan.clone(), + verbose, + plan, stringified_plans, - schema: e.schema.clone(), + schema, logical_optimization_succeeded: false, })); } @@ -1877,7 +1887,7 @@ impl SessionState { // optimize the child plan, capturing the output of each optimizer let optimized_plan = self.optimizer.optimize( - &analyzed_plan, + analyzed_plan, self, |optimized_plan, optimizer| { let optimizer_name = optimizer.name().to_string(); @@ -1885,29 +1895,33 @@ impl SessionState { stringified_plans.push(optimized_plan.to_stringified(plan_type)); }, ); + let (plan, logical_optimization_succeeded) = match optimized_plan { - Ok(plan) => (Arc::new(plan), 
true), + Ok(plan) => (Arc::new(plan.data), true), Err(DataFusionError::Context(optimizer_name, err)) => { + // TODO show explain error let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; stringified_plans .push(StringifiedPlan::new(plan_type, err.to_string())); - (e.plan.clone(), false) + (original_plan, false) } Err(e) => return Err(e), }; Ok(LogicalPlan::Explain(Explain { - verbose: e.verbose, + verbose, plan, stringified_plans, - schema: e.schema.clone(), + schema, logical_optimization_succeeded, })) } else { let analyzed_plan = self.analyzer - .execute_and_check(plan, self.options(), |_, _| {})?; - self.optimizer.optimize(&analyzed_plan, self, |_, _| {}) + .execute_and_check(&plan, self.options(), |_, _| {})?; + self.optimizer + .optimize(analyzed_plan, self, |_, _| {}) + .map(|t| t.data) } } @@ -1920,7 +1934,7 @@ impl SessionState { /// DDL `CREATE TABLE` must be handled by another layer. pub async fn create_physical_plan( &self, - logical_plan: &LogicalPlan, + logical_plan: LogicalPlan, ) -> Result> { let logical_plan = self.optimize(logical_plan)?; self.query_planner diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 08fe3380061f..a83ad20ed778 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -17,6 +17,7 @@ //! 
Logical plan types +use std::cell::OnceCell; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; @@ -38,7 +39,7 @@ use crate::utils::{ split_conjunction, }; use crate::{ - build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, + build_join_schema, expr_vec_fmt, lit, BinaryExpr, BuiltInWindowFunction, CreateMemoryTable, CreateView, Expr, ExprSchemable, LogicalPlanBuilder, Operator, TableProviderFilterPushDown, TableSource, WindowFunctionDefinition, }; @@ -360,6 +361,253 @@ impl LogicalPlan { | LogicalPlan::Prepare(_) => Ok(()), } } +} + +/// rewrites each element in the iterator using `f` +pub fn rewrite_iter_mut<'a, F>( + i: impl IntoIterator, + mut f: F, +) -> Result<()> +where + F: FnMut(Expr) -> Result, +{ + i.into_iter().try_for_each(|e| rewrite_expr(e, &mut f)) +} + +pub fn rewrite_expr<'a, F>(e: &'a mut Expr, mut f: F) -> Result<()> +where + F: FnMut(Expr) -> Result, +{ + let mut t = lit(0); + std::mem::swap(e, &mut t); + // transform + let mut t = f(t)?; + // put it back + std::mem::swap(e, &mut t); + Ok(()) +} + +impl LogicalPlan { + /// applies the closure `f` to each expression of this node, potentially + /// rewriting it in place + /// + /// If the closure returns an error, the error is returned and the expressions + /// are left in a partially modified state + pub fn rewrite_exprs(mut self, mut f: F) -> Result + where + F: FnMut(Expr) -> Result, + { + match &mut self { + LogicalPlan::Projection(Projection { expr, .. }) => { + rewrite_iter_mut(expr.iter_mut(), &mut f)?; + } + LogicalPlan::Values(Values { values, .. }) => { + rewrite_iter_mut(values.iter_mut().flatten(), &mut f)?; + } + LogicalPlan::Filter(Filter { predicate, .. }) => { + rewrite_expr(predicate, &mut f)? + } + LogicalPlan::Repartition(Repartition { + partitioning_scheme, + .. 
+ }) => match partitioning_scheme { + Partitioning::Hash(expr, _) => rewrite_iter_mut(expr.iter_mut(), &mut f)?, + Partitioning::DistributeBy(expr) => { + rewrite_iter_mut(expr.iter_mut(), &mut f)? + } + Partitioning::RoundRobinBatch(_) => {} + }, + LogicalPlan::Window(Window { window_expr, .. }) => { + rewrite_iter_mut(window_expr.iter_mut(), &mut f)?; + } + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + .. + }) => rewrite_iter_mut( + group_expr.iter_mut().chain(aggr_expr.iter_mut()), + &mut f, + )?, + // There are two parts of expressions for join, equijoin(on) and non-equijoin(filter). + // 1. the first part is `on.len()` equijoin expressions, and the structure of each expr is `left-on = right-on`. + // 2. the second part is non-equijoin(filter). + LogicalPlan::Join(Join { on, filter, .. }) => { + // don't look at the equijoin expressions as a whole + rewrite_iter_mut( + on.iter_mut().flat_map(|(e1, e2)| { + std::iter::once(e1).chain(std::iter::once(e2)) + }), + &mut f, + )?; + + if let Some(filter) = filter.as_mut() { + rewrite_expr(filter, &mut f)?; + } + } + LogicalPlan::Sort(Sort { expr, .. }) => { + rewrite_iter_mut(expr.iter_mut(), &mut f)? + } + LogicalPlan::Extension(extension) => { + // would be nice to avoid this copy -- maybe can + // update extension to just observe Exprs + //extension.node.expressions().iter().try_for_each(f) + todo!(); + } + LogicalPlan::TableScan(TableScan { filters, .. }) => { + rewrite_iter_mut(filters.iter_mut(), &mut f)?; + } + LogicalPlan::Unnest(Unnest { column, .. }) => { + //f(&Expr::Column(column.clone())) + todo!(); + } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. 
+ })) => rewrite_iter_mut( + on_expr + .iter_mut() + .chain(select_expr.iter_mut()) + .chain(sort_expr.iter_mut().flat_map(|x| x.iter_mut())), + &mut f, + )?, + // plans without expressions + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Union(_) + | LogicalPlan::Distinct(Distinct::All(_)) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Prepare(_) => {} + } + + Ok(self) + } +} + +const PLACEHOLDER: OnceCell> = OnceCell::new(); + +// applies f to rewrite the logical plan, replacing `node` +// +// ideally we would remove the Arc nonsense entirely from LogicalPlan and have it own its inputs +// however, for now do a horrible hack +// +// On rewrite the existing plan is destroyed +fn rewrite_arc(node: &mut Arc, f: &mut F) -> Result<()> +where + F: FnMut(LogicalPlan) -> Result, +{ + let mut new_node = PLACEHOLDER + .get_or_init(|| { + Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(DFSchema::empty()), + })) + }) + .clone(); + + // take the old value out of the Arc + std::mem::swap(node, &mut new_node); + + let new_node = match Arc::try_unwrap(new_node) { + Ok(node) => { + //println!("Unwrapped arc yay"); + node + } + Err(node) => { + //println!("Failed to unwrap arc boo"); + node.as_ref().clone() + } + }; + // do the actual transform + let mut new_node = f(new_node).map(Arc::new)?; + // put the new value back into the Arc + std::mem::swap(node, &mut new_node); + + Ok(()) +} + +impl LogicalPlan { + /// applies the closure `f` to each input of this node, replacing the existing inputs + /// with the result of the closure. 
+ pub fn rewrite_inputs(mut self, mut f: F) -> Result + where + F: FnMut(LogicalPlan) -> Result, + { + match &mut self { + LogicalPlan::Projection(Projection { input, .. }) => { + rewrite_arc(input, &mut f)? + } + LogicalPlan::Filter(Filter { input, .. }) => rewrite_arc(input, &mut f)?, + LogicalPlan::Repartition(Repartition { input, .. }) => { + rewrite_arc(input, &mut f)? + } + LogicalPlan::Window(Window { input, .. }) => rewrite_arc(input, &mut f)?, + LogicalPlan::Aggregate(Aggregate { input, .. }) => { + rewrite_arc(input, &mut f)? + } + LogicalPlan::Sort(Sort { input, .. }) => rewrite_arc(input, &mut f)?, + LogicalPlan::Join(Join { left, right, .. }) => { + rewrite_arc(left, &mut f)?; + rewrite_arc(right, &mut f)?; + } + LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { + rewrite_arc(left, &mut f)?; + rewrite_arc(right, &mut f)?; + } + LogicalPlan::Limit(Limit { input, .. }) => rewrite_arc(input, &mut f)?, + LogicalPlan::Subquery(Subquery { subquery, .. }) => { + rewrite_arc(subquery, &mut f)? + } + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { + rewrite_arc(input, &mut f)? + } + LogicalPlan::Extension(extension) => todo!(), + LogicalPlan::Union(Union { inputs, .. }) => { + inputs + .iter_mut() + .try_for_each(|input| rewrite_arc(input, &mut f))?; + } + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => rewrite_arc(input, &mut f)?, + LogicalPlan::Explain(explain) => rewrite_arc(&mut explain.plan, &mut f)?, + LogicalPlan::Analyze(analyze) => rewrite_arc(&mut analyze.input, &mut f)?, + LogicalPlan::Dml(write) => rewrite_arc(&mut write.input, &mut f)?, + LogicalPlan::Copy(copy) => rewrite_arc(&mut copy.input, &mut f)?, + LogicalPlan::Ddl(ddl) => { + todo!(); + } + LogicalPlan::Unnest(Unnest { input, .. }) => rewrite_arc(input, &mut f)?, + LogicalPlan::Prepare(Prepare { input, .. }) => rewrite_arc(input, &mut f)?, + LogicalPlan::RecursiveQuery(RecursiveQuery { + static_term, + recursive_term, + .. 
+ }) => { + rewrite_arc(static_term, &mut f)?; + rewrite_arc(recursive_term, &mut f)?; + } + // plans without inputs + LogicalPlan::TableScan { .. } + | LogicalPlan::Statement { .. } + | LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Values { .. } + | LogicalPlan::DescribeTable(_) => {} + } + + Ok(self) + } /// returns all inputs of this `LogicalPlan` node. Does not /// include inputs to inputs, or subqueries. diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index fe63766fc265..65cde9bda06d 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -48,10 +48,11 @@ use crate::utils::log_plan; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{not_impl_err, DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; use chrono::{DateTime, Utc}; +use datafusion_common::tree_node::Transformed; use log::{debug, warn}; /// `OptimizerRule` transforms one [`LogicalPlan`] into another which @@ -85,6 +86,20 @@ pub trait OptimizerRule { fn apply_order(&self) -> Option { None } + + /// does this rule support rewriting owned plans (to reduce copying)? + fn supports_owned(&self) -> bool { + false + } + + /// if supports_owned returns true, calls try_optimize_owned + fn try_optimize_owned( + &self, + _plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result, DataFusionError> { + not_impl_err!("try_optimized_owned is not implemented for this rule") + } } /// Options to control the DataFusion Optimizer. 
@@ -279,10 +294,10 @@ impl Optimizer { /// invoking observer function after each call pub fn optimize( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, mut observer: F, - ) -> Result + ) -> Result> where F: FnMut(&LogicalPlan, &dyn OptimizerRule), { @@ -291,6 +306,7 @@ impl Optimizer { let start_time = Instant::now(); + let mut transformed = false; let mut previous_plans = HashSet::with_capacity(16); previous_plans.insert(LogicalPlanSignature::new(&new_plan)); @@ -299,21 +315,31 @@ impl Optimizer { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); for rule in &self.rules { + // if we are skipping failed rules, we need to keep a copy of the plan in case the optimizer fails + let prev_plan = if options.optimizer.skip_failed_rules { + Some(new_plan.clone()) + } else { + None + }; + + let orig_schema = plan.schema().clone(); + let result = - self.optimize_recursively(rule, &new_plan, config) + self.optimize_recursively(rule, new_plan, config) .and_then(|plan| { - if let Some(plan) = &plan { - assert_schema_is_the_same(rule.name(), plan, &new_plan)?; - } + assert_has_schema(rule.name(), &orig_schema, &plan.data)?; Ok(plan) }); - match result { - Ok(Some(plan)) => { - new_plan = plan; + + match (result, prev_plan) { + (Ok(t), _) if t.transformed => { + transformed = true; + new_plan = t.data; observer(&new_plan, rule.as_ref()); log_plan(rule.name(), &new_plan); } - Ok(None) => { + (Ok(t), _) => { + new_plan = t.data; observer(&new_plan, rule.as_ref()); debug!( "Plan unchanged by optimizer rule '{}' (pass {})", @@ -321,22 +347,22 @@ impl Optimizer { i ); } - Err(e) => { - if options.optimizer.skip_failed_rules { - // Note to future readers: if you see this warning it signals a - // bug in the DataFusion optimizer. 
Please consider filing a ticket - // https://github.com/apache/arrow-datafusion - warn!( + (Err(e), Some(prev_plan)) => { + // Note to future readers: if you see this warning it signals a + // bug in the DataFusion optimizer. Please consider filing a ticket + // https://github.com/apache/arrow-datafusion + warn!( "Skipping optimizer rule '{}' due to unexpected error: {}", rule.name(), e ); - } else { - return Err(DataFusionError::Context( - format!("Optimizer rule '{}' failed", rule.name(),), - Box::new(e), - )); - } + new_plan = prev_plan; + } + (Err(e), None) => { + return Err(DataFusionError::Context( + format!("Optimizer rule '{}' failed", rule.name(),), + Box::new(e), + )); } } } @@ -354,45 +380,54 @@ impl Optimizer { } log_plan("Final optimized plan", &new_plan); debug!("Optimizer took {} ms", start_time.elapsed().as_millis()); - Ok(new_plan) + Ok(if transformed { + Transformed::yes(new_plan) + } else { + Transformed::no(new_plan) + }) } fn optimize_node( &self, rule: &Arc, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result> { - // TODO: future feature: We can do Batch optimize - rule.try_optimize(plan, config) + ) -> Result> { + if rule.supports_owned() { + rule.try_optimize_owned(plan, config) + } else { + // TODO: future feature: We can do Batch optimize + rule.try_optimize(&plan, config).map(|opt| { + if let Some(opt_plan) = opt { + Transformed::yes(opt_plan) + } else { + // return original plan + Transformed::no(plan) + } + }) + } } fn optimize_inputs( &self, rule: &Arc, - plan: &LogicalPlan, + mut plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result> { - let inputs = plan.inputs(); - let result = inputs - .iter() - .map(|sub_plan| self.optimize_recursively(rule, sub_plan, config)) - .collect::>>()?; - if result.is_empty() || result.iter().all(|o| o.is_none()) { - return Ok(None); - } - - let new_inputs = result - .into_iter() - .zip(inputs) - .map(|(new_plan, old_plan)| match new_plan { - Some(plan) => plan, 
- None => old_plan.clone(), - }) - .collect(); - - let exprs = plan.expressions(); - plan.with_new_exprs(exprs, new_inputs).map(Some) + ) -> Result> { + let mut transformed = false; + let plan = plan.rewrite_inputs(|child| { + let t = self.optimize_recursively(rule, child, config)?; + if t.transformed { + transformed = true; + } + Ok(t.data) + })?; + + Ok(if transformed { + Transformed::yes(plan) + } else { + Transformed::no(plan) + }) } /// Use a rule to optimize the whole plan. @@ -400,53 +435,53 @@ impl Optimizer { pub fn optimize_recursively( &self, rule: &Arc, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { match rule.apply_order() { Some(order) => match order { ApplyOrder::TopDown => { - let optimize_self_opt = self.optimize_node(rule, plan, config)?; - let optimize_inputs_opt = match &optimize_self_opt { - Some(optimized_plan) => { - self.optimize_inputs(rule, optimized_plan, config)? - } - _ => self.optimize_inputs(rule, plan, config)?, - }; - Ok(optimize_inputs_opt.or(optimize_self_opt)) + let optimized_plan = self.optimize_node(rule, plan, config)?; + let transformed = optimized_plan.transformed; + + // TODO make a nicer 'and_then' type API on Transformed + let optimized_plan = + self.optimize_inputs(rule, optimized_plan.data, config)?; + Ok(if transformed || optimized_plan.transformed { + Transformed::yes(optimized_plan.data) + } else { + Transformed::no(optimized_plan.data) + }) } ApplyOrder::BottomUp => { - let optimize_inputs_opt = self.optimize_inputs(rule, plan, config)?; - let optimize_self_opt = match &optimize_inputs_opt { - Some(optimized_plan) => { - self.optimize_node(rule, optimized_plan, config)? 
- } - _ => self.optimize_node(rule, plan, config)?, - }; - Ok(optimize_self_opt.or(optimize_inputs_opt)) + let optimized_plan = self.optimize_inputs(rule, plan, config)?; + let transformed = optimized_plan.transformed; + let optimized_plan = + self.optimize_node(rule, optimized_plan.data, config)?; + + Ok(if transformed || optimized_plan.transformed { + Transformed::yes(optimized_plan.data) + } else { + Transformed::no(optimized_plan.data) + }) } }, - _ => rule.try_optimize(plan, config), + _ => self.optimize_node(rule, plan, config), } } } -/// Returns an error if plans have different schemas. -/// -/// It ignores metadata and nullability. -pub(crate) fn assert_schema_is_the_same( +pub(crate) fn assert_has_schema( rule_name: &str, - prev_plan: &LogicalPlan, + schema: &DFSchema, new_plan: &LogicalPlan, ) -> Result<()> { - let equivalent = new_plan - .schema() - .equivalent_names_and_types(prev_plan.schema()); + let equivalent = new_plan.schema().equivalent_names_and_types(schema); if !equivalent { let e = DataFusionError::Internal(format!( "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", - prev_plan.schema(), + schema, new_plan.schema() )); Err(DataFusionError::Context( @@ -458,6 +493,17 @@ pub(crate) fn assert_schema_is_the_same( } } +/// Returns an error if plans have different schemas. +/// +/// It ignores metadata and nullability. 
+pub(crate) fn assert_schema_is_the_same( + rule_name: &str, + prev_plan: &LogicalPlan, + new_plan: &LogicalPlan, +) -> Result<()> { + assert_has_schema(rule_name, prev_plan.schema(), new_plan) +} + #[cfg(test)] mod tests { use std::sync::{Arc, Mutex}; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 70b163acc208..31c2d4af6b3c 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -19,7 +19,8 @@ use std::sync::Arc; -use datafusion_common::{DFSchema, DFSchemaRef, Result}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::{not_impl_err, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::simplify::SimplifyContext; @@ -55,20 +56,33 @@ impl OptimizerRule for SimplifyExpressions { plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { + return not_impl_err!("Should use optimized owned"); + } + + fn supports_owned(&self) -> bool { + true + } + + /// if supports_owned returns true, calls try_optimize_owned + fn try_optimize_owned( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result, DataFusionError> { let mut execution_props = ExecutionProps::new(); execution_props.query_execution_start_time = config.query_execution_start_time(); - Ok(Some(Self::optimize_internal(plan, &execution_props)?)) + Self::optimize_internal(plan, &execution_props) } } impl SimplifyExpressions { fn optimize_internal( - plan: &LogicalPlan, + plan: LogicalPlan, execution_props: &ExecutionProps, - ) -> Result { + ) -> Result> { let schema = if !plan.inputs().is_empty() { DFSchemaRef::new(merge_schema(plan.inputs())) - } else if let LogicalPlan::TableScan(scan) = plan { + } else if let LogicalPlan::TableScan(scan) = 
&plan { // When predicates are pushed into a table scan, there is no input // schema to resolve predicates against, so it must be handled specially // @@ -88,11 +102,15 @@ impl SimplifyExpressions { }; let info = SimplifyContext::new(execution_props).with_schema(schema); - let new_inputs = plan - .inputs() - .iter() - .map(|input| Self::optimize_internal(input, execution_props)) - .collect::>>()?; + // rewrite all inputs + let mut transformed = false; + let plan = plan.rewrite_inputs(&mut |plan| { + let t = Self::optimize_internal(plan, execution_props)?; + if t.transformed { + transformed = true; + } + Ok(t.data) + })?; let simplifier = ExprSimplifier::new(info); @@ -109,18 +127,35 @@ impl SimplifyExpressions { simplifier }; - let exprs = plan - .expressions() - .into_iter() - .map(|e| { - // TODO: unify with `rewrite_preserving_name` - let original_name = e.name_for_alias()?; - let new_e = simplifier.simplify(e)?; - new_e.alias_if_changed(original_name) + let is_filter = matches!(plan, LogicalPlan::Filter(_)); + + let plan = plan.rewrite_exprs(|e| { + // no aliasing for filters + if is_filter { + return simplifier.simplify(e); + } + + // TODO: unify with `rewrite_preserving_name` + // todo track if e was rewritten + let original_name = e.name_for_alias()?; + let new_e = simplifier.simplify(e)?; + + // inline new_e.alias_if_changed(original_name) + // to figure out if the expression was transformed + let new_name = new_e.name_for_alias()?; + Ok(if new_name == original_name { + new_e + } else { + transformed = true; + new_e.alias(original_name) }) - .collect::>>()?; + })?; - plan.with_new_exprs(exprs, new_inputs) + Ok(if transformed { + Transformed::yes(plan) + } else { + Transformed::no(plan) + }) } } diff --git a/datafusion/sqllogictest/test_files/aal.slt b/datafusion/sqllogictest/test_files/aal.slt new file mode 100644 index 000000000000..14ca43471eaf --- /dev/null +++ b/datafusion/sqllogictest/test_files/aal.slt @@ -0,0 +1,19 @@ + +statement ok +create table t 
as values (1), (2); + +query I +select column1 + column1 from t; +---- +2 +4 + +query TT +explain select column1 + column1, 2+3 from t; +---- +logical_plan +Projection: t.column1 + t.column1, Int64(5) AS Int64(2) + Int64(3) +--TableScan: t projection=[column1] +physical_plan +ProjectionExec: expr=[column1@0 + column1@0 as t.column1 + t.column1, 5 as Int64(2) + Int64(3)] +--MemoryExec: partitions=1, partition_sizes=[1]