diff --git a/rust/datafusion/src/logical_plan/plan.rs b/rust/datafusion/src/logical_plan/plan.rs index c04bdb37187..b1219a95881 100644 --- a/rust/datafusion/src/logical_plan/plan.rs +++ b/rust/datafusion/src/logical_plan/plan.rs @@ -28,9 +28,12 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use crate::datasource::TableProvider; use crate::sql::parser::FileType; -use super::display::{GraphvizVisitor, IndentVisitor}; use super::expr::Expr; use super::extension::UserDefinedLogicalNode; +use super::{ + col, + display::{GraphvizVisitor, IndentVisitor}, +}; use crate::logical_plan::dfschema::DFSchemaRef; /// Join type @@ -238,6 +241,63 @@ impl LogicalPlan { Field::new("plan", DataType::Utf8, false), ])) } + + /// returns all expressions (non-recursively) in the current + /// logical plan node. This does not include expressions in any + /// children + pub fn expressions(self: &LogicalPlan) -> Vec<Expr> { + match self { + LogicalPlan::Projection { expr, .. } => expr.clone(), + LogicalPlan::Filter { predicate, .. } => vec![predicate.clone()], + LogicalPlan::Repartition { + partitioning_scheme, + .. + } => match partitioning_scheme { + Partitioning::Hash(expr, _) => expr.clone(), + _ => vec![], + }, + LogicalPlan::Aggregate { + group_expr, + aggr_expr, + .. + } => { + let mut result = group_expr.clone(); + result.extend(aggr_expr.clone()); + result + } + LogicalPlan::Join { on, .. } => { + on.iter().flat_map(|(l, r)| vec![col(l), col(r)]).collect() + } + LogicalPlan::Sort { expr, .. } => expr.clone(), + LogicalPlan::Extension { node } => node.expressions(), + // plans without expressions + LogicalPlan::TableScan { .. } + | LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Limit { .. } + | LogicalPlan::CreateExternalTable { .. } + | LogicalPlan::Explain { .. } => vec![], + } + } + + /// returns all inputs of this `LogicalPlan` node. Does not + /// include inputs to inputs. 
+ pub fn inputs(self: &LogicalPlan) -> Vec<&LogicalPlan> { + match self { + LogicalPlan::Projection { input, .. } => vec![input], + LogicalPlan::Filter { input, .. } => vec![input], + LogicalPlan::Repartition { input, .. } => vec![input], + LogicalPlan::Aggregate { input, .. } => vec![input], + LogicalPlan::Sort { input, .. } => vec![input], + LogicalPlan::Join { left, right, .. } => vec![left, right], + LogicalPlan::Limit { input, .. } => vec![input], + LogicalPlan::Extension { node } => node.inputs(), + // plans without inputs + LogicalPlan::TableScan { .. } + | LogicalPlan::EmptyRelation { .. } + | LogicalPlan::CreateExternalTable { .. } + | LogicalPlan::Explain { .. } => vec![], + } + } } /// Logical partitioning schemes supported by the repartition operator. diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index 62f5ee30c62..ab39daee276 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -73,13 +73,14 @@ impl OptimizerRule for ConstantFolding { | LogicalPlan::Limit { .. } | LogicalPlan::Join { .. 
} => { // apply the optimization to all inputs of the plan - let inputs = utils::inputs(plan); + let inputs = plan.inputs(); let new_inputs = inputs .iter() .map(|plan| self.optimize(plan)) .collect::<Result<Vec<_>>>()?; - let expr = utils::expressions(plan) + let expr = plan + .expressions() .into_iter() .map(|e| e.rewrite(&mut rewriter)) .collect::<Result<Vec<_>>>()?; diff --git a/rust/datafusion/src/optimizer/filter_push_down.rs b/rust/datafusion/src/optimizer/filter_push_down.rs index 60ab97b0570..41596c5beeb 100644 --- a/rust/datafusion/src/optimizer/filter_push_down.rs +++ b/rust/datafusion/src/optimizer/filter_push_down.rs @@ -143,12 +143,13 @@ fn get_join_predicates<'a>( /// Optimizes the plan fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> { - let new_inputs = utils::inputs(&plan) + let new_inputs = plan + .inputs() .iter() .map(|input| optimize(input, state.clone())) .collect::<Result<Vec<_>>>()?; - let expr = utils::expressions(&plan); + let expr = plan.expressions(); utils::from_plan(&plan, &expr, &new_inputs) } @@ -326,7 +327,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> { let right = optimize(right, right_state)?; // create a new Join with the new `left` and `right` - let expr = utils::expressions(&plan); + let expr = plan.expressions(); let plan = utils::from_plan(&plan, &expr, &[left, right])?; if keep.0.is_empty() { diff --git a/rust/datafusion/src/optimizer/hash_build_probe_order.rs b/rust/datafusion/src/optimizer/hash_build_probe_order.rs index e6ad905e73d..a47195aba2a 100644 --- a/rust/datafusion/src/optimizer/hash_build_probe_order.rs +++ b/rust/datafusion/src/optimizer/hash_build_probe_order.rs @@ -146,10 +146,10 @@ impl OptimizerRule for HashBuildProbeOrder { | LogicalPlan::CreateExternalTable { .. } | LogicalPlan::Explain { .. } | LogicalPlan::Extension { .. 
} => { - let expr = utils::expressions(plan); + let expr = plan.expressions(); // apply the optimization to all inputs of the plan - let inputs = utils::inputs(plan); + let inputs = plan.inputs(); let new_inputs = inputs .iter() .map(|plan| self.optimize(plan)) diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs index 16076e960fd..115da4b3010 100644 --- a/rust/datafusion/src/optimizer/projection_push_down.rs +++ b/rust/datafusion/src/optimizer/projection_push_down.rs @@ -280,12 +280,12 @@ fn optimize_plan( | LogicalPlan::Sort { .. } | LogicalPlan::CreateExternalTable { .. } | LogicalPlan::Extension { .. } => { - let expr = utils::expressions(plan); + let expr = plan.expressions(); // collect all required columns by this plan utils::exprlist_to_column_names(&expr, &mut new_required_columns)?; // apply the optimization to all inputs of the plan - let inputs = utils::inputs(plan); + let inputs = plan.inputs(); let new_inputs = inputs .iter() .map(|plan| { diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 31d18b91801..c2552d77d2d 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -26,7 +26,7 @@ use crate::logical_plan::{ Expr, LogicalPlan, Operator, Partitioning, PlanType, Recursion, StringifiedPlan, ToDFSchema, }; -use crate::prelude::{col, lit}; +use crate::prelude::lit; use crate::scalar::ScalarValue; use crate::{ error::{DataFusionError, Result}, @@ -144,8 +144,9 @@ pub fn optimize_children( ); } - let new_exprs = expressions(&plan); - let new_inputs = inputs(&plan) + let new_exprs = plan.expressions(); + let new_inputs = plan + .inputs() .into_iter() .map(|plan| optimizer.optimize(plan)) .collect::<Result<Vec<_>>>()?; @@ -153,60 +154,6 @@ pub fn optimize_children( from_plan(plan, &new_exprs, &new_inputs) } -/// returns all expressions (non-recursively) in the current logical plan node. 
-pub fn expressions(plan: &LogicalPlan) -> Vec<Expr> { - match plan { - LogicalPlan::Projection { expr, .. } => expr.clone(), - LogicalPlan::Filter { predicate, .. } => vec![predicate.clone()], - LogicalPlan::Repartition { - partitioning_scheme, - .. - } => match partitioning_scheme { - Partitioning::Hash(expr, _) => expr.clone(), - _ => vec![], - }, - LogicalPlan::Aggregate { - group_expr, - aggr_expr, - .. - } => { - let mut result = group_expr.clone(); - result.extend(aggr_expr.clone()); - result - } - LogicalPlan::Join { on, .. } => { - on.iter().flat_map(|(l, r)| vec![col(l), col(r)]).collect() - } - LogicalPlan::Sort { expr, .. } => expr.clone(), - LogicalPlan::Extension { node } => node.expressions(), - // plans without expressions - LogicalPlan::TableScan { .. } - | LogicalPlan::EmptyRelation { .. } - | LogicalPlan::Limit { .. } - | LogicalPlan::CreateExternalTable { .. } - | LogicalPlan::Explain { .. } => vec![], - } -} - -/// returns all inputs in the logical plan -pub fn inputs(plan: &LogicalPlan) -> Vec<&LogicalPlan> { - match plan { - LogicalPlan::Projection { input, .. } => vec![input], - LogicalPlan::Filter { input, .. } => vec![input], - LogicalPlan::Repartition { input, .. } => vec![input], - LogicalPlan::Aggregate { input, .. } => vec![input], - LogicalPlan::Sort { input, .. } => vec![input], - LogicalPlan::Join { left, right, .. } => vec![left, right], - LogicalPlan::Limit { input, .. } => vec![input], - LogicalPlan::Extension { node } => node.inputs(), - // plans without inputs - LogicalPlan::TableScan { .. } - | LogicalPlan::EmptyRelation { .. } - | LogicalPlan::CreateExternalTable { .. } - | LogicalPlan::Explain { .. } => vec![], - } -} - /// Returns a new logical plan based on the original one with inputs and expressions replaced pub fn from_plan( plan: &LogicalPlan,