From 2b5aa24d50766294122456a54d8e691e17c5c031 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 24 Jan 2021 15:36:17 -0800 Subject: [PATCH 01/11] ARROW-11366: [Datafusion] support boolean literal in comparison expression --- rust/datafusion/src/execution/context.rs | 29 +- .../src/optimizer/boolean_comparison.rs | 270 ++++++++++++++++++ rust/datafusion/src/optimizer/mod.rs | 1 + rust/datafusion/src/sql/planner.rs | 12 + rust/rustfmt.toml | 3 +- 5 files changed, 312 insertions(+), 3 deletions(-) create mode 100644 rust/datafusion/src/optimizer/boolean_comparison.rs diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 0870b77d0cb..d9771f977c2 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -40,6 +40,7 @@ use crate::execution::dataframe_impl::DataFrameImpl; use crate::logical_plan::{ FunctionRegistry, LogicalPlan, LogicalPlanBuilder, ToDFSchema, }; +use crate::optimizer::boolean_comparison::BooleanComparison; use crate::optimizer::filter_push_down::FilterPushDown; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; @@ -512,6 +513,7 @@ impl ExecutionConfig { concurrency: num_cpus::get(), batch_size: 32768, optimizers: vec![ + Arc::new(BooleanComparison::new()), Arc::new(ProjectionPushDown::new()), Arc::new(FilterPushDown::new()), Arc::new(HashBuildProbeOrder::new()), @@ -834,7 +836,7 @@ mod tests { projected_schema, .. 
} => { - assert_eq!(source.schema().fields().len(), 2); + assert_eq!(source.schema().fields().len(), 3); assert_eq!(projected_schema.fields().len(), 1); } _ => panic!("input to projection should be TableScan"), @@ -1146,6 +1148,28 @@ mod tests { Ok(()) } + #[tokio::test] + async fn boolean_literal() -> Result<()> { + let results = + execute("SELECT c1, c3 FROM test WHERE c1 > 2 AND c3 = true", 4).await?; + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+------+", + "| c1 | c3 |", + "+----+------+", + "| 3 | true |", + "| 3 | true |", + "| 3 | true |", + "| 3 | true |", + "| 3 | true |", + "+----+------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) + } + #[tokio::test] async fn aggregate_grouped_empty() -> Result<()> { let results = @@ -1953,6 +1977,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::UInt32, false), Field::new("c2", DataType::UInt64, false), + Field::new("c3", DataType::Boolean, false), ])); // generate a partitioned file @@ -1963,7 +1988,7 @@ mod tests { // generate some data for i in 0..=10 { - let data = format!("{},{}\n", partition, i); + let data = format!("{},{},{}\n", partition, i, i % 2 == 0); file.write_all(data.as_bytes())?; } } diff --git a/rust/datafusion/src/optimizer/boolean_comparison.rs b/rust/datafusion/src/optimizer/boolean_comparison.rs new file mode 100644 index 00000000000..cfbda474ebd --- /dev/null +++ b/rust/datafusion/src/optimizer/boolean_comparison.rs @@ -0,0 +1,270 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Boolean comparision rule rewrites redudant comparison expression involing boolean literal into +//! unary expression. + +use std::sync::Arc; + +use crate::error::Result; +use crate::logical_plan::{Expr, LogicalPlan, Operator}; +use crate::optimizer::optimizer::OptimizerRule; +use crate::optimizer::utils; +use crate::scalar::ScalarValue; + +/// Optimizer that simplifies comparison expressions involving boolean literals. +/// +/// Recursively go through all expressionss and simplify the following cases: +/// * `expr = ture` to `expr` +/// * `expr = false` to `!expr` +pub struct BooleanComparison {} + +impl BooleanComparison { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for BooleanComparison { + fn optimize(&mut self, plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { + predicate: optimize_expr(predicate), + input: Arc::new(self.optimize(input)?), + }), + // Rest: recurse into plan, apply optimization where possible + LogicalPlan::Projection { .. } + | LogicalPlan::Aggregate { .. } + | LogicalPlan::Limit { .. } + | LogicalPlan::Repartition { .. } + | LogicalPlan::CreateExternalTable { .. } + | LogicalPlan::Extension { .. } + | LogicalPlan::Sort { .. } + | LogicalPlan::Explain { .. } + | LogicalPlan::Join { .. 
} => { + let expr = utils::expressions(plan); + + // apply the optimization to all inputs of the plan + let inputs = utils::inputs(plan); + let new_inputs = inputs + .iter() + .map(|plan| self.optimize(plan)) + .collect::>>()?; + + utils::from_plan(plan, &expr, &new_inputs) + } + LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } => { + Ok(plan.clone()) + } + } + } + + fn name(&self) -> &str { + "boolean_comparison" + } +} + +/// Recursively transverses the logical plan. +fn optimize_expr(e: &Expr) -> Expr { + match e { + Expr::BinaryExpr { left, op, right } => { + let left = optimize_expr(left); + let right = optimize_expr(right); + match op { + Operator::Eq => match (&left, &right) { + (Expr::Literal(ScalarValue::Boolean(b)), _) => match b { + Some(true) => right, + Some(false) | None => Expr::Not(Box::new(right)), + }, + (_, Expr::Literal(ScalarValue::Boolean(b))) => match b { + Some(true) => left, + Some(false) | None => Expr::Not(Box::new(left)), + }, + _ => Expr::BinaryExpr { + left: Box::new(left), + op: Operator::Eq, + right: Box::new(right), + }, + }, + Operator::NotEq => match (&left, &right) { + (Expr::Literal(ScalarValue::Boolean(b)), _) => match b { + Some(false) | None => right, + Some(true) => Expr::Not(Box::new(right)), + }, + (_, Expr::Literal(ScalarValue::Boolean(b))) => match b { + Some(false) | None => left, + Some(true) => Expr::Not(Box::new(left)), + }, + _ => Expr::BinaryExpr { + left: Box::new(left), + op: Operator::NotEq, + right: Box::new(right), + }, + }, + _ => Expr::BinaryExpr { + left: Box::new(left), + op: op.clone(), + right: Box::new(right), + }, + } + } + Expr::Not(expr) => Expr::Not(Box::new(optimize_expr(&expr))), + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + if expr.is_none() { + // recurse into CASE WHEN condition expressions + Expr::Case { + expr: None, + when_then_expr: when_then_expr + .iter() + .map(|(when, then)| (Box::new(optimize_expr(when)), then.clone())) + .collect(), + else_expr: 
else_expr.clone(), + } + } else { + // when base expression is specified, when_then_expr conditions are literal values + // so we can just skip this case + e.clone() + } + } + Expr::Alias { .. } + | Expr::Negative { .. } + | Expr::Column { .. } + | Expr::InList { .. } + | Expr::IsNotNull { .. } + | Expr::IsNull { .. } + | Expr::Cast { .. } + | Expr::ScalarVariable { .. } + | Expr::Between { .. } + | Expr::Literal { .. } + | Expr::ScalarFunction { .. } + | Expr::ScalarUDF { .. } + | Expr::AggregateFunction { .. } + | Expr::AggregateUDF { .. } + | Expr::Sort { .. } + | Expr::Wildcard => e.clone(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::{col, lit, LogicalPlanBuilder}; + use crate::test::*; + + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { + let mut rule = BooleanComparison::new(); + let optimized_plan = rule.optimize(plan).expect("failed to optimize plan"); + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + } + + #[test] + fn simplify_eq_expr() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(&table_scan) + .filter(col("a").eq(lit(true)))? + .filter(col("b").eq(lit(false)))? + .project(vec![col("a")])? + .build()?; + + let expected = "\ + Projection: #a\ + \n Filter: NOT #b\ + \n Filter: #a\ + \n TableScan: test projection=None"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + #[test] + fn simplify_not_eq_expr() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(&table_scan) + .filter(col("a").not_eq(lit(true)))? + .filter(col("b").not_eq(lit(false)))? + .limit(1)? + .project(vec![col("a")])? 
+ .build()?; + + let expected = "\ + Projection: #a\ + \n Limit: 1\ + \n Filter: #b\ + \n Filter: NOT #a\ + \n TableScan: test projection=None"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + #[test] + fn simplify_and_expr() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(&table_scan) + .filter(col("a").not_eq(lit(true)).and(col("b").eq(lit(true))))? + .project(vec![col("a")])? + .build()?; + + let expected = "\ + Projection: #a\ + \n Filter: NOT #a And #b\ + \n TableScan: test projection=None"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + #[test] + fn simplify_or_expr() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(&table_scan) + .filter(col("a").not_eq(lit(true)).or(col("b").eq(lit(false))))? + .project(vec![col("a")])? + .build()?; + + let expected = "\ + Projection: #a\ + \n Filter: NOT #a Or NOT #b\ + \n TableScan: test projection=None"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + #[test] + fn simplify_not_expr() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(&table_scan) + .filter(col("a").eq(lit(false)).not())? + .project(vec![col("a")])? + .build()?; + + let expected = "\ + Projection: #a\ + \n Filter: NOT NOT #a\ + \n TableScan: test projection=None"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } +} diff --git a/rust/datafusion/src/optimizer/mod.rs b/rust/datafusion/src/optimizer/mod.rs index 91a338eb8e6..3616593ade1 100644 --- a/rust/datafusion/src/optimizer/mod.rs +++ b/rust/datafusion/src/optimizer/mod.rs @@ -18,6 +18,7 @@ //! This module contains a query optimizer that operates against a logical plan and applies //! some simple rules to a logical plan, such as "Projection Push Down" and "Type Coercion". 
+pub mod boolean_comparison; pub mod filter_push_down; pub mod hash_build_probe_order; pub mod optimizer; diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 58310f50856..fc56052b29f 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -2341,6 +2341,17 @@ mod tests { quick_test(sql, expected); } + #[test] + fn boolean_literal_in_condition_expression() { + let sql = "SELECT order_id \ + FROM orders \ + WHERE delivered = false OR delivered = true"; + let expected = "Projection: #order_id\ + \n Filter: #delivered Eq Boolean(false) Or #delivered Eq Boolean(true)\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + #[test] fn select_typedstring() { let sql = "SELECT date '2020-12-10' AS date FROM person"; @@ -2389,6 +2400,7 @@ mod tests { Field::new("o_item_id", DataType::Utf8, false), Field::new("qty", DataType::Int32, false), Field::new("price", DataType::Float64, false), + Field::new("delivered", DataType::Boolean, false), ])), "lineitem" => Some(Schema::new(vec![ Field::new("l_item_id", DataType::UInt32, false), diff --git a/rust/rustfmt.toml b/rust/rustfmt.toml index c114c6f1b73..c49cccdd9f5 100644 --- a/rust/rustfmt.toml +++ b/rust/rustfmt.toml @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. 
+edition = "2018" max_width = 90 # ignore generated files # ignore = [ # "arrow/src/ipc/gen", -#] \ No newline at end of file +#] From dded02c9ce74a77db39dadba1fc11ea3f54178ee Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Tue, 26 Jan 2021 21:36:24 -0800 Subject: [PATCH 02/11] add expression tess --- .../src/optimizer/boolean_comparison.rs | 104 +++++++++++++++++- 1 file changed, 99 insertions(+), 5 deletions(-) diff --git a/rust/datafusion/src/optimizer/boolean_comparison.rs b/rust/datafusion/src/optimizer/boolean_comparison.rs index cfbda474ebd..d51834715b8 100644 --- a/rust/datafusion/src/optimizer/boolean_comparison.rs +++ b/rust/datafusion/src/optimizer/boolean_comparison.rs @@ -170,6 +170,100 @@ mod tests { use crate::logical_plan::{col, lit, LogicalPlanBuilder}; use crate::test::*; + #[test] + fn optimize_expr_eq() -> Result<()> { + assert_eq!( + optimize_expr(&Expr::BinaryExpr { + left: Box::new(lit(1)), + op: Operator::Eq, + right: Box::new(lit(true)), + }), + lit(1), + ); + + assert_eq!( + optimize_expr(&Expr::BinaryExpr { + left: Box::new(lit("a")), + op: Operator::Eq, + right: Box::new(lit(false)), + }), + lit("a").not(), + ); + + Ok(()) + } + + #[test] + fn optimize_expr_not_eq() -> Result<()> { + assert_eq!( + optimize_expr(&Expr::BinaryExpr { + left: Box::new(lit(1)), + op: Operator::NotEq, + right: Box::new(lit(true)), + }), + lit(1).not(), + ); + + assert_eq!( + optimize_expr(&Expr::BinaryExpr { + left: Box::new(lit("a")), + op: Operator::NotEq, + right: Box::new(lit(false)), + }), + lit("a"), + ); + + Ok(()) + } + + #[test] + fn optimize_expr_not_not_eq() -> Result<()> { + assert_eq!( + optimize_expr(&Expr::Not(Box::new(Expr::BinaryExpr { + left: Box::new(lit(1)), + op: Operator::NotEq, + right: Box::new(lit(true)), + }))), + lit(1).not().not(), + ); + + assert_eq!( + optimize_expr(&Expr::Not(Box::new(Expr::BinaryExpr { + left: Box::new(lit("a")), + op: Operator::NotEq, + right: Box::new(lit(false)), + }))), + lit("a").not(), + ); + + Ok(()) + 
} + + #[test] + fn optimize_expr_case_when_then_else() -> Result<()> { + assert_eq!( + optimize_expr(&Box::new(Expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(Expr::BinaryExpr { + left: Box::new(lit("a")), + op: Operator::NotEq, + right: Box::new(lit(false)), + }), + Box::new(lit("ok")), + )], + else_expr: Some(Box::new(lit("not ok"))), + })), + Expr::Case { + expr: None, + when_then_expr: vec![(Box::new(lit("a")), Box::new(lit("ok")))], + else_expr: Some(Box::new(lit("not ok"))), + } + ); + + Ok(()) + } + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { let mut rule = BooleanComparison::new(); let optimized_plan = rule.optimize(plan).expect("failed to optimize plan"); @@ -178,7 +272,7 @@ mod tests { } #[test] - fn simplify_eq_expr() -> Result<()> { + fn optimize_plan_eq_expr() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) .filter(col("a").eq(lit(true)))? @@ -197,7 +291,7 @@ mod tests { } #[test] - fn simplify_not_eq_expr() -> Result<()> { + fn optimize_plan_not_eq_expr() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) .filter(col("a").not_eq(lit(true)))? @@ -218,7 +312,7 @@ mod tests { } #[test] - fn simplify_and_expr() -> Result<()> { + fn optimize_plan_and_expr() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) .filter(col("a").not_eq(lit(true)).and(col("b").eq(lit(true))))? @@ -235,7 +329,7 @@ mod tests { } #[test] - fn simplify_or_expr() -> Result<()> { + fn optimize_plan_or_expr() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) .filter(col("a").not_eq(lit(true)).or(col("b").eq(lit(false))))? 
@@ -252,7 +346,7 @@ mod tests { } #[test] - fn simplify_not_expr() -> Result<()> { + fn optimize_plan_not_expr() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) .filter(col("a").eq(lit(false)).not())? From 53fe8b09d69cbfe8cd87cb72dd2e7e63aa29d9b4 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 30 Jan 2021 15:11:02 -0800 Subject: [PATCH 03/11] rename to constant folding --- rust/datafusion/src/execution/context.rs | 4 ++-- .../{boolean_comparison.rs => constant_folding.rs} | 12 ++++++------ rust/datafusion/src/optimizer/mod.rs | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) rename rust/datafusion/src/optimizer/{boolean_comparison.rs => constant_folding.rs} (97%) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index d9771f977c2..06ef1c67ead 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -40,7 +40,7 @@ use crate::execution::dataframe_impl::DataFrameImpl; use crate::logical_plan::{ FunctionRegistry, LogicalPlan, LogicalPlanBuilder, ToDFSchema, }; -use crate::optimizer::boolean_comparison::BooleanComparison; +use crate::optimizer::constant_folding::ConstantFolding; use crate::optimizer::filter_push_down::FilterPushDown; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; @@ -513,7 +513,7 @@ impl ExecutionConfig { concurrency: num_cpus::get(), batch_size: 32768, optimizers: vec![ - Arc::new(BooleanComparison::new()), + Arc::new(ConstantFolding::new()), Arc::new(ProjectionPushDown::new()), Arc::new(FilterPushDown::new()), Arc::new(HashBuildProbeOrder::new()), diff --git a/rust/datafusion/src/optimizer/boolean_comparison.rs b/rust/datafusion/src/optimizer/constant_folding.rs similarity index 97% rename from rust/datafusion/src/optimizer/boolean_comparison.rs rename to rust/datafusion/src/optimizer/constant_folding.rs index 
d51834715b8..236cacd9b20 100644 --- a/rust/datafusion/src/optimizer/boolean_comparison.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -31,17 +31,17 @@ use crate::scalar::ScalarValue; /// Recursively go through all expressionss and simplify the following cases: /// * `expr = ture` to `expr` /// * `expr = false` to `!expr` -pub struct BooleanComparison {} +pub struct ConstantFolding {} -impl BooleanComparison { +impl ConstantFolding { #[allow(missing_docs)] pub fn new() -> Self { Self {} } } -impl OptimizerRule for BooleanComparison { - fn optimize(&mut self, plan: &LogicalPlan) -> Result { +impl OptimizerRule for ConstantFolding { + fn optimize(&self, plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { predicate: optimize_expr(predicate), @@ -75,7 +75,7 @@ impl OptimizerRule for BooleanComparison { } fn name(&self) -> &str { - "boolean_comparison" + "constant_folding" } } @@ -265,7 +265,7 @@ mod tests { } fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let mut rule = BooleanComparison::new(); + let mut rule = ConstantFolding::new(); let optimized_plan = rule.optimize(plan).expect("failed to optimize plan"); let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); diff --git a/rust/datafusion/src/optimizer/mod.rs b/rust/datafusion/src/optimizer/mod.rs index 3616593ade1..d8dc74a64a4 100644 --- a/rust/datafusion/src/optimizer/mod.rs +++ b/rust/datafusion/src/optimizer/mod.rs @@ -18,7 +18,7 @@ //! This module contains a query optimizer that operates against a logical plan and applies //! some simple rules to a logical plan, such as "Projection Push Down" and "Type Coercion". 
-pub mod boolean_comparison; +pub mod constant_folding; pub mod filter_push_down; pub mod hash_build_probe_order; pub mod optimizer; From 4e4c514b31d069724d2c8bb5a620354dd00bac49 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 7 Feb 2021 00:17:30 -0800 Subject: [PATCH 04/11] ignore nonboolean expressions --- .../src/optimizer/constant_folding.rs | 432 ++++++++++++++---- 1 file changed, 344 insertions(+), 88 deletions(-) diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index 236cacd9b20..03f7dcded0f 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -20,8 +20,10 @@ use std::sync::Arc; +use arrow::datatypes::DataType; + use crate::error::Result; -use crate::logical_plan::{Expr, LogicalPlan, Operator}; +use crate::logical_plan::{DFSchemaRef, Expr, LogicalPlan, Operator}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::scalar::ScalarValue; @@ -29,8 +31,10 @@ use crate::scalar::ScalarValue; /// Optimizer that simplifies comparison expressions involving boolean literals. 
/// /// Recursively go through all expressionss and simplify the following cases: -/// * `expr = ture` to `expr` -/// * `expr = false` to `!expr` +/// * `expr = ture` and `expr != false` to `expr` when `expr` is of boolean type +/// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean type +/// * `true = true` and `false = false` to `true` +/// * `false = true` and `true = false` to `false` pub struct ConstantFolding {} impl ConstantFolding { @@ -44,7 +48,7 @@ impl OptimizerRule for ConstantFolding { fn optimize(&self, plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: optimize_expr(predicate), + predicate: optimize_expr(predicate, plan.schema())?, input: Arc::new(self.optimize(input)?), }), // Rest: recurse into plan, apply optimization where possible @@ -80,21 +84,35 @@ impl OptimizerRule for ConstantFolding { } /// Recursively transverses the logical plan. -fn optimize_expr(e: &Expr) -> Expr { - match e { +fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { + Ok(match e { Expr::BinaryExpr { left, op, right } => { - let left = optimize_expr(left); - let right = optimize_expr(right); + let left = optimize_expr(left, schema)?; + let right = optimize_expr(right, schema)?; match op { Operator::Eq => match (&left, &right) { - (Expr::Literal(ScalarValue::Boolean(b)), _) => match b { - Some(true) => right, - Some(false) | None => Expr::Not(Box::new(right)), - }, - (_, Expr::Literal(ScalarValue::Boolean(b))) => match b { - Some(true) => left, - Some(false) | None => Expr::Not(Box::new(left)), - }, + ( + Expr::Literal(ScalarValue::Boolean(l)), + Expr::Literal(ScalarValue::Boolean(r)), + ) => Expr::Literal(ScalarValue::Boolean(Some( + l.unwrap_or(false) == r.unwrap_or(false), + ))), + (Expr::Literal(ScalarValue::Boolean(b)), _) + if right.get_type(schema)? 
== DataType::Boolean => + { + match b { + Some(true) => right, + Some(false) | None => Expr::Not(Box::new(right)), + } + } + (_, Expr::Literal(ScalarValue::Boolean(b))) + if left.get_type(schema)? == DataType::Boolean => + { + match b { + Some(true) => left, + Some(false) | None => Expr::Not(Box::new(left)), + } + } _ => Expr::BinaryExpr { left: Box::new(left), op: Operator::Eq, @@ -102,14 +120,28 @@ fn optimize_expr(e: &Expr) -> Expr { }, }, Operator::NotEq => match (&left, &right) { - (Expr::Literal(ScalarValue::Boolean(b)), _) => match b { - Some(false) | None => right, - Some(true) => Expr::Not(Box::new(right)), - }, - (_, Expr::Literal(ScalarValue::Boolean(b))) => match b { - Some(false) | None => left, - Some(true) => Expr::Not(Box::new(left)), - }, + ( + Expr::Literal(ScalarValue::Boolean(l)), + Expr::Literal(ScalarValue::Boolean(r)), + ) => Expr::Literal(ScalarValue::Boolean(Some( + l.unwrap_or(false) != r.unwrap_or(false), + ))), + (Expr::Literal(ScalarValue::Boolean(b)), _) + if right.get_type(schema)? == DataType::Boolean => + { + match b { + Some(false) | None => right, + Some(true) => Expr::Not(Box::new(right)), + } + } + (_, Expr::Literal(ScalarValue::Boolean(b))) + if left.get_type(schema)? 
== DataType::Boolean => + { + match b { + Some(false) | None => left, + Some(true) => Expr::Not(Box::new(left)), + } + } _ => Expr::BinaryExpr { left: Box::new(left), op: Operator::NotEq, @@ -123,7 +155,7 @@ fn optimize_expr(e: &Expr) -> Expr { }, } } - Expr::Not(expr) => Expr::Not(Box::new(optimize_expr(&expr))), + Expr::Not(expr) => Expr::Not(Box::new(optimize_expr(&expr, schema)?)), Expr::Case { expr, when_then_expr, @@ -135,8 +167,10 @@ fn optimize_expr(e: &Expr) -> Expr { expr: None, when_then_expr: when_then_expr .iter() - .map(|(when, then)| (Box::new(optimize_expr(when)), then.clone())) - .collect(), + .map(|(when, then)| { + Ok((Box::new(optimize_expr(when, schema)?), then.clone())) + }) + .collect::>()?, else_expr: else_expr.clone(), } } else { @@ -161,33 +195,163 @@ fn optimize_expr(e: &Expr) -> Expr { | Expr::AggregateUDF { .. } | Expr::Sort { .. } | Expr::Wildcard => e.clone(), - } + }) } #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::{col, lit, LogicalPlanBuilder}; - use crate::test::*; + use crate::logical_plan::{col, lit, DFField, DFSchema, LogicalPlanBuilder}; + + use arrow::datatypes::*; + + fn test_table_scan() -> Result { + let schema = Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Boolean, false), + Field::new("d", DataType::UInt32, false), + ]); + LogicalPlanBuilder::scan_empty("test", &schema, None)?.build() + } + + fn expr_test_schema() -> DFSchemaRef { + Arc::new( + DFSchema::new(vec![ + DFField::new(None, "c1", DataType::Utf8, true), + DFField::new(None, "c2", DataType::Boolean, true), + ]) + .unwrap(), + ) + } #[test] fn optimize_expr_eq() -> Result<()> { + let schema = expr_test_schema(); + assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); + + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(lit(true)), + op: Operator::Eq, + right: Box::new(lit(true)), + }, + &schema + )?, + lit(true), + ); + + 
assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(lit(true)), + op: Operator::Eq, + right: Box::new(lit(false)), + }, + &schema + )?, + lit(false), + ); + + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(col("c2")), + op: Operator::Eq, + right: Box::new(lit(true)), + }, + &schema + )?, + col("c2"), + ); + + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(col("c2")), + op: Operator::Eq, + right: Box::new(lit(false)), + }, + &schema + )?, + col("c2").not(), + ); + + Ok(()) + } + + #[test] + fn optimize_expr_eq_skip_nonboolean_type() -> Result<()> { + let schema = expr_test_schema(); + + // when one of the operand is not of boolean type, folding the other boolean constant will + // change return type of expression to non-boolean. + assert_eq!(col("c1").get_type(&schema)?, DataType::Utf8); + assert_eq!( - optimize_expr(&Expr::BinaryExpr { + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(col("c1")), + op: Operator::Eq, + right: Box::new(lit(true)), + }, + &schema + )?, + Expr::BinaryExpr { + left: Box::new(col("c1")), + op: Operator::Eq, + right: Box::new(lit(true)), + }, + ); + + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(col("c1")), + op: Operator::Eq, + right: Box::new(lit(false)), + }, + &schema + )?, + Expr::BinaryExpr { + left: Box::new(col("c1")), + op: Operator::Eq, + right: Box::new(lit(false)), + }, + ); + + // test constant operands + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(lit(1)), + op: Operator::Eq, + right: Box::new(lit(true)), + }, + &schema + )?, + Expr::BinaryExpr { left: Box::new(lit(1)), op: Operator::Eq, right: Box::new(lit(true)), - }), - lit(1), + }, ); assert_eq!( - optimize_expr(&Expr::BinaryExpr { + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(lit("a")), + op: Operator::Eq, + right: Box::new(lit(false)), + }, + &schema + )?, + Expr::BinaryExpr { left: Box::new(lit("a")), op: Operator::Eq, right: Box::new(lit(false)), - 
}), - lit("a").not(), + }, ); Ok(()) @@ -195,45 +359,132 @@ mod tests { #[test] fn optimize_expr_not_eq() -> Result<()> { + let schema = expr_test_schema(); + assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); + assert_eq!( - optimize_expr(&Expr::BinaryExpr { - left: Box::new(lit(1)), - op: Operator::NotEq, - right: Box::new(lit(true)), - }), - lit(1).not(), + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(col("c2")), + op: Operator::NotEq, + right: Box::new(lit(true)), + }, + &schema + )?, + col("c2").not(), ); assert_eq!( - optimize_expr(&Expr::BinaryExpr { - left: Box::new(lit("a")), - op: Operator::NotEq, - right: Box::new(lit(false)), - }), - lit("a"), + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(col("c2")), + op: Operator::NotEq, + right: Box::new(lit(false)), + }, + &schema + )?, + col("c2"), + ); + + // test constant + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(lit(true)), + op: Operator::NotEq, + right: Box::new(lit(true)), + }, + &schema + )?, + lit(false), + ); + + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(lit(true)), + op: Operator::NotEq, + right: Box::new(lit(false)), + }, + &schema + )?, + lit(true), ); Ok(()) } #[test] - fn optimize_expr_not_not_eq() -> Result<()> { + fn optimize_expr_not_eq_skip_nonboolean_type() -> Result<()> { + let schema = expr_test_schema(); + + // when one of the operand is not of boolean type, folding the other boolean constant will + // change return type of expression to non-boolean. 
+ assert_eq!(col("c1").get_type(&schema)?, DataType::Utf8); + assert_eq!( - optimize_expr(&Expr::Not(Box::new(Expr::BinaryExpr { + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(col("c1")), + op: Operator::NotEq, + right: Box::new(lit(true)), + }, + &schema + )?, + Expr::BinaryExpr { + left: Box::new(col("c1")), + op: Operator::NotEq, + right: Box::new(lit(true)), + }, + ); + + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(col("c1")), + op: Operator::NotEq, + right: Box::new(lit(false)), + }, + &schema + )?, + Expr::BinaryExpr { + left: Box::new(col("c1")), + op: Operator::NotEq, + right: Box::new(lit(false)), + }, + ); + + // test constants + assert_eq!( + optimize_expr( + &Expr::Not(Box::new(Expr::BinaryExpr { + left: Box::new(lit(1)), + op: Operator::NotEq, + right: Box::new(lit(true)), + })), + &schema + )?, + Expr::Not(Box::new(Expr::BinaryExpr { left: Box::new(lit(1)), op: Operator::NotEq, right: Box::new(lit(true)), - }))), - lit(1).not().not(), + })), ); assert_eq!( - optimize_expr(&Expr::Not(Box::new(Expr::BinaryExpr { + optimize_expr( + &Expr::Not(Box::new(Expr::BinaryExpr { + left: Box::new(lit("a")), + op: Operator::NotEq, + right: Box::new(lit(false)), + })), + &schema + )?, + Expr::Not(Box::new(Expr::BinaryExpr { left: Box::new(lit("a")), op: Operator::NotEq, right: Box::new(lit(false)), - }))), - lit("a").not(), + })), ); Ok(()) @@ -241,22 +492,27 @@ mod tests { #[test] fn optimize_expr_case_when_then_else() -> Result<()> { + let schema = expr_test_schema(); + assert_eq!( - optimize_expr(&Box::new(Expr::Case { - expr: None, - when_then_expr: vec![( - Box::new(Expr::BinaryExpr { - left: Box::new(lit("a")), - op: Operator::NotEq, - right: Box::new(lit(false)), - }), - Box::new(lit("ok")), - )], - else_expr: Some(Box::new(lit("not ok"))), - })), + optimize_expr( + &Box::new(Expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(Expr::BinaryExpr { + left: Box::new(col("c2")), + op: Operator::NotEq, + right: 
Box::new(lit(false)), + }), + Box::new(lit("ok")), + )], + else_expr: Some(Box::new(lit("not ok"))), + }), + &schema + )?, Expr::Case { expr: None, - when_then_expr: vec![(Box::new(lit("a")), Box::new(lit("ok")))], + when_then_expr: vec![(Box::new(col("c2")), Box::new(lit("ok")))], else_expr: Some(Box::new(lit("not ok"))), } ); @@ -265,7 +521,7 @@ mod tests { } fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let mut rule = ConstantFolding::new(); + let rule = ConstantFolding::new(); let optimized_plan = rule.optimize(plan).expect("failed to optimize plan"); let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); @@ -275,15 +531,15 @@ mod tests { fn optimize_plan_eq_expr() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .filter(col("a").eq(lit(true)))? - .filter(col("b").eq(lit(false)))? - .project(vec![col("a")])? + .filter(col("b").eq(lit(true)))? + .filter(col("c").eq(lit(false)))? + .project(&[col("a")])? .build()?; let expected = "\ Projection: #a\ - \n Filter: NOT #b\ - \n Filter: #a\ + \n Filter: NOT #c\ + \n Filter: #b\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -294,17 +550,17 @@ mod tests { fn optimize_plan_not_eq_expr() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .filter(col("a").not_eq(lit(true)))? - .filter(col("b").not_eq(lit(false)))? + .filter(col("b").not_eq(lit(true)))? + .filter(col("c").not_eq(lit(false)))? .limit(1)? - .project(vec![col("a")])? + .project(&[col("a")])? 
.build()?; let expected = "\ Projection: #a\ \n Limit: 1\ - \n Filter: #b\ - \n Filter: NOT #a\ + \n Filter: #c\ + \n Filter: NOT #b\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -315,13 +571,13 @@ mod tests { fn optimize_plan_and_expr() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .filter(col("a").not_eq(lit(true)).and(col("b").eq(lit(true))))? - .project(vec![col("a")])? + .filter(col("b").not_eq(lit(true)).and(col("c").eq(lit(true))))? + .project(&[col("a")])? .build()?; let expected = "\ Projection: #a\ - \n Filter: NOT #a And #b\ + \n Filter: NOT #b And #c\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -332,13 +588,13 @@ mod tests { fn optimize_plan_or_expr() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .filter(col("a").not_eq(lit(true)).or(col("b").eq(lit(false))))? - .project(vec![col("a")])? + .filter(col("b").not_eq(lit(true)).or(col("c").eq(lit(false))))? + .project(&[col("a")])? .build()?; let expected = "\ Projection: #a\ - \n Filter: NOT #a Or NOT #b\ + \n Filter: NOT #b Or NOT #c\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -349,13 +605,13 @@ mod tests { fn optimize_plan_not_expr() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .filter(col("a").eq(lit(false)).not())? - .project(vec![col("a")])? + .filter(col("b").eq(lit(false)).not())? + .project(&[col("a")])? 
.build()?; let expected = "\ Projection: #a\ - \n Filter: NOT NOT #a\ + \n Filter: NOT NOT #b\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); From 91cdc0c5ee500a6d52be984347763e949d69c202 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 7 Feb 2021 12:04:28 -0800 Subject: [PATCH 05/11] optimize !!expr to expr --- .../src/optimizer/constant_folding.rs | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index 03f7dcded0f..1262d8e3f99 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -31,10 +31,11 @@ use crate::scalar::ScalarValue; /// Optimizer that simplifies comparison expressions involving boolean literals. /// /// Recursively go through all expressionss and simplify the following cases: -/// * `expr = ture` and `expr != false` to `expr` when `expr` is of boolean type +/// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type /// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean type /// * `true = true` and `false = false` to `true` /// * `false = true` and `true = false` to `false` +/// * `!!expr` to `expr` pub struct ConstantFolding {} impl ConstantFolding { @@ -155,7 +156,10 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { }, } } - Expr::Not(expr) => Expr::Not(Box::new(optimize_expr(&expr, schema)?)), + Expr::Not(expr) => match &**expr { + Expr::Not(inner) => optimize_expr(&inner, schema)?, + _ => Expr::Not(Box::new(optimize_expr(&expr, schema)?)), + }, Expr::Case { expr, when_then_expr, @@ -225,6 +229,22 @@ mod tests { ) } + #[test] + fn optimize_expr_not_not() -> Result<()> { + let schema = expr_test_schema(); + assert_eq!( + optimize_expr( + &Expr::Not(Box::new(Expr::Not(Box::new(Expr::Not(Box::new(col( + "c2" + ))))))), + &schema + )?, + col("c2").not(), 
+ ); + + Ok(()) + } + #[test] fn optimize_expr_eq() -> Result<()> { let schema = expr_test_schema(); From f49bd12da816524d832c992485422f476204b42f Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 7 Feb 2021 12:18:22 -0800 Subject: [PATCH 06/11] handle null comparision --- .../src/optimizer/constant_folding.rs | 86 ++++++++++++++++--- 1 file changed, 76 insertions(+), 10 deletions(-) diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index 1262d8e3f99..9cdd3158c64 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -36,6 +36,7 @@ use crate::scalar::ScalarValue; /// * `true = true` and `false = false` to `true` /// * `false = true` and `true = false` to `false` /// * `!!expr` to `expr` +/// * `expr = null` and `expr != null` to `null` pub struct ConstantFolding {} impl ConstantFolding { @@ -95,15 +96,19 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { ( Expr::Literal(ScalarValue::Boolean(l)), Expr::Literal(ScalarValue::Boolean(r)), - ) => Expr::Literal(ScalarValue::Boolean(Some( - l.unwrap_or(false) == r.unwrap_or(false), - ))), + ) => match (l, r) { + (Some(l), Some(r)) => { + Expr::Literal(ScalarValue::Boolean(Some(l == r))) + } + _ => Expr::Literal(ScalarValue::Boolean(None)), + }, (Expr::Literal(ScalarValue::Boolean(b)), _) if right.get_type(schema)? 
== DataType::Boolean => { match b { Some(true) => right, - Some(false) | None => Expr::Not(Box::new(right)), + Some(false) => Expr::Not(Box::new(right)), + None => Expr::Literal(ScalarValue::Boolean(None)), } } (_, Expr::Literal(ScalarValue::Boolean(b))) @@ -111,7 +116,8 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { { match b { Some(true) => left, - Some(false) | None => Expr::Not(Box::new(left)), + Some(false) => Expr::Not(Box::new(left)), + None => Expr::Literal(ScalarValue::Boolean(None)), } } _ => Expr::BinaryExpr { @@ -124,23 +130,28 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { ( Expr::Literal(ScalarValue::Boolean(l)), Expr::Literal(ScalarValue::Boolean(r)), - ) => Expr::Literal(ScalarValue::Boolean(Some( - l.unwrap_or(false) != r.unwrap_or(false), - ))), + ) => match (l, r) { + (Some(l), Some(r)) => { + Expr::Literal(ScalarValue::Boolean(Some(l != r))) + } + _ => Expr::Literal(ScalarValue::Boolean(None)), + }, (Expr::Literal(ScalarValue::Boolean(b)), _) if right.get_type(schema)? == DataType::Boolean => { match b { - Some(false) | None => right, Some(true) => Expr::Not(Box::new(right)), + Some(false) => right, + None => Expr::Literal(ScalarValue::Boolean(None)), } } (_, Expr::Literal(ScalarValue::Boolean(b))) if left.get_type(schema)? 
== DataType::Boolean => { match b { - Some(false) | None => left, Some(true) => Expr::Not(Box::new(left)), + Some(false) => left, + None => Expr::Literal(ScalarValue::Boolean(None)), } } _ => Expr::BinaryExpr { @@ -245,6 +256,61 @@ mod tests { Ok(()) } + #[test] + fn optimize_expr_null_comparision() -> Result<()> { + let schema = expr_test_schema(); + + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(lit(true)), + op: Operator::Eq, + right: Box::new(Expr::Literal(ScalarValue::Boolean(None))), + }, + &schema + )?, + Expr::Literal(ScalarValue::Boolean(None)), + ); + + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(Expr::Literal(ScalarValue::Boolean(None))), + op: Operator::NotEq, + right: Box::new(Expr::Literal(ScalarValue::Boolean(None))), + }, + &schema + )?, + Expr::Literal(ScalarValue::Boolean(None)), + ); + + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(col("c2")), + op: Operator::NotEq, + right: Box::new(Expr::Literal(ScalarValue::Boolean(None))), + }, + &schema + )?, + Expr::Literal(ScalarValue::Boolean(None)), + ); + + assert_eq!( + optimize_expr( + &Expr::BinaryExpr { + left: Box::new(Expr::Literal(ScalarValue::Boolean(None))), + op: Operator::Eq, + right: Box::new(col("c2")), + }, + &schema + )?, + Expr::Literal(ScalarValue::Boolean(None)), + ); + + Ok(()) + } + #[test] fn optimize_expr_eq() -> Result<()> { let schema = expr_test_schema(); From 00baa002a60dcec9d7a254591dad3620965c9374 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 7 Feb 2021 14:21:44 -0800 Subject: [PATCH 07/11] optimize then and else_expr branches for case expression --- .../src/optimizer/constant_folding.rs | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index 9cdd3158c64..81c31d43456 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ 
b/rust/datafusion/src/optimizer/constant_folding.rs @@ -162,7 +162,7 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { }, _ => Expr::BinaryExpr { left: Box::new(left), - op: op.clone(), + op: *op, right: Box::new(right), }, } @@ -183,10 +183,16 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { when_then_expr: when_then_expr .iter() .map(|(when, then)| { - Ok((Box::new(optimize_expr(when, schema)?), then.clone())) + Ok(( + Box::new(optimize_expr(when, schema)?), + Box::new(optimize_expr(then, schema)?), + )) }) .collect::>()?, - else_expr: else_expr.clone(), + else_expr: match else_expr { + Some(e) => Some(Box::new(optimize_expr(e, schema)?)), + None => None, + }, } } else { // when base expression is specified, when_then_expr conditions are literal values @@ -590,16 +596,19 @@ mod tests { op: Operator::NotEq, right: Box::new(lit(false)), }), - Box::new(lit("ok")), + Box::new(lit("ok").eq(lit(true))), )], - else_expr: Some(Box::new(lit("not ok"))), + else_expr: Some(Box::new(col("c2").eq(lit(true)))), }), &schema )?, Expr::Case { expr: None, - when_then_expr: vec![(Box::new(col("c2")), Box::new(lit("ok")))], - else_expr: Some(Box::new(lit("not ok"))), + when_then_expr: vec![( + Box::new(col("c2")), + Box::new(lit("ok").eq(lit(true))) + )], + else_expr: Some(Box::new(col("c2"))), } ); From 1e573703d24ebc6a804998f5aa27f788793743df Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 13 Feb 2021 00:14:38 -0800 Subject: [PATCH 08/11] recursive into all logical plan nodes and expression types --- rust/datafusion/src/logical_plan/plan.rs | 34 +++ .../src/optimizer/constant_folding.rs | 223 ++++++++++++++---- 2 files changed, 205 insertions(+), 52 deletions(-) diff --git a/rust/datafusion/src/logical_plan/plan.rs b/rust/datafusion/src/logical_plan/plan.rs index 2afdefda1b0..c04bdb37187 100644 --- a/rust/datafusion/src/logical_plan/plan.rs +++ b/rust/datafusion/src/logical_plan/plan.rs @@ -197,6 +197,40 @@ impl LogicalPlan { } } + /// Get a 
vector of references to all schemas in every node of the logical plan + pub fn all_schemas(&self) -> Vec<&DFSchemaRef> { + match self { + LogicalPlan::TableScan { + projected_schema, .. + } => vec![&projected_schema], + LogicalPlan::Aggregate { input, schema, .. } + | LogicalPlan::Projection { input, schema, .. } => { + let mut schemas = input.all_schemas(); + schemas.insert(0, &schema); + schemas + } + LogicalPlan::Join { + left, + right, + schema, + .. + } => { + let mut schemas = left.all_schemas(); + schemas.extend(right.all_schemas()); + schemas.insert(0, &schema); + schemas + } + LogicalPlan::Extension { node } => vec![&node.schema()], + LogicalPlan::Explain { schema, .. } + | LogicalPlan::EmptyRelation { schema, .. } + | LogicalPlan::CreateExternalTable { schema, .. } => vec![&schema], + LogicalPlan::Limit { input, .. } + | LogicalPlan::Repartition { input, .. } + | LogicalPlan::Sort { input, .. } + | LogicalPlan::Filter { input, .. } => input.all_schemas(), + } + } + /// Returns the (fixed) output schema for explain plans pub fn explain_schema() -> SchemaRef { SchemaRef::new(Schema::new(vec![ diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index 81c31d43456..a20017f63e7 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -50,21 +50,19 @@ impl OptimizerRule for ConstantFolding { fn optimize(&self, plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: optimize_expr(predicate, plan.schema())?, + predicate: optimize_expr(predicate, &plan.all_schemas())?, input: Arc::new(self.optimize(input)?), }), // Rest: recurse into plan, apply optimization where possible LogicalPlan::Projection { .. } | LogicalPlan::Aggregate { .. } - | LogicalPlan::Limit { .. } | LogicalPlan::Repartition { .. } | LogicalPlan::CreateExternalTable { .. } | LogicalPlan::Extension { .. 
} | LogicalPlan::Sort { .. } | LogicalPlan::Explain { .. } + | LogicalPlan::Limit { .. } | LogicalPlan::Join { .. } => { - let expr = utils::expressions(plan); - // apply the optimization to all inputs of the plan let inputs = utils::inputs(plan); let new_inputs = inputs @@ -72,6 +70,12 @@ impl OptimizerRule for ConstantFolding { .map(|plan| self.optimize(plan)) .collect::>>()?; + let schemas = plan.all_schemas(); + let expr = utils::expressions(plan) + .iter() + .map(|e| optimize_expr(e, &schemas)) + .collect::>>()?; + utils::from_plan(plan, &expr, &new_inputs) } LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } => { @@ -85,12 +89,27 @@ impl OptimizerRule for ConstantFolding { } } +fn is_boolean_type(expr: &Expr, schemas: &[&DFSchemaRef]) -> bool { + for schema in schemas { + match expr.get_type(schema) { + Ok(dt) if dt == DataType::Boolean => { + return true; + } + _ => { + continue; + } + } + } + + false +} + /// Recursively transverses the logical plan. -fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { +fn optimize_expr(e: &Expr, schemas: &[&DFSchemaRef]) -> Result { Ok(match e { Expr::BinaryExpr { left, op, right } => { - let left = optimize_expr(left, schema)?; - let right = optimize_expr(right, schema)?; + let left = optimize_expr(left, schemas)?; + let right = optimize_expr(right, schemas)?; match op { Operator::Eq => match (&left, &right) { ( @@ -103,7 +122,7 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { _ => Expr::Literal(ScalarValue::Boolean(None)), }, (Expr::Literal(ScalarValue::Boolean(b)), _) - if right.get_type(schema)? == DataType::Boolean => + if is_boolean_type(&right, schemas) => { match b { Some(true) => right, @@ -112,7 +131,7 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { } } (_, Expr::Literal(ScalarValue::Boolean(b))) - if left.get_type(schema)? 
== DataType::Boolean => + if is_boolean_type(&left, schemas) => { match b { Some(true) => left, @@ -137,7 +156,7 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { _ => Expr::Literal(ScalarValue::Boolean(None)), }, (Expr::Literal(ScalarValue::Boolean(b)), _) - if right.get_type(schema)? == DataType::Boolean => + if is_boolean_type(&right, schemas) => { match b { Some(true) => Expr::Not(Box::new(right)), @@ -146,7 +165,7 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { } } (_, Expr::Literal(ScalarValue::Boolean(b))) - if left.get_type(schema)? == DataType::Boolean => + if is_boolean_type(&left, schemas) => { match b { Some(true) => Expr::Not(Box::new(left)), @@ -168,8 +187,8 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { } } Expr::Not(expr) => match &**expr { - Expr::Not(inner) => optimize_expr(&inner, schema)?, - _ => Expr::Not(Box::new(optimize_expr(&expr, schema)?)), + Expr::Not(inner) => optimize_expr(&inner, schemas)?, + _ => Expr::Not(Box::new(optimize_expr(&expr, schemas)?)), }, Expr::Case { expr, @@ -184,13 +203,13 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { .iter() .map(|(when, then)| { Ok(( - Box::new(optimize_expr(when, schema)?), - Box::new(optimize_expr(then, schema)?), + Box::new(optimize_expr(when, schemas)?), + Box::new(optimize_expr(then, schemas)?), )) }) .collect::>()?, else_expr: match else_expr { - Some(e) => Some(Box::new(optimize_expr(e, schema)?)), + Some(e) => Some(Box::new(optimize_expr(e, schemas)?)), None => None, }, } @@ -200,21 +219,84 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { e.clone() } } - Expr::Alias { .. } - | Expr::Negative { .. } - | Expr::Column { .. } - | Expr::InList { .. } - | Expr::IsNotNull { .. } - | Expr::IsNull { .. } - | Expr::Cast { .. 
} + Expr::Alias(expr, name) => { + Expr::Alias(Box::new(optimize_expr(expr, schemas)?), name.clone()) + } + Expr::Negative(expr) => Expr::Negative(Box::new(optimize_expr(expr, schemas)?)), + Expr::InList { + expr, + list, + negated, + } => Expr::InList { + expr: Box::new(optimize_expr(expr, schemas)?), + list: list + .iter() + .map(|e| optimize_expr(e, schemas)) + .collect::>()?, + negated: *negated, + }, + Expr::IsNotNull(expr) => Expr::IsNotNull(Box::new(optimize_expr(expr, schemas)?)), + Expr::IsNull(expr) => Expr::IsNull(Box::new(optimize_expr(expr, schemas)?)), + Expr::Cast { expr, data_type } => Expr::Cast { + expr: Box::new(optimize_expr(expr, schemas)?), + data_type: data_type.clone(), + }, + Expr::Between { + expr, + negated, + low, + high, + } => Expr::Between { + expr: Box::new(optimize_expr(expr, schemas)?), + negated: *negated, + low: Box::new(optimize_expr(low, schemas)?), + high: Box::new(optimize_expr(high, schemas)?), + }, + Expr::ScalarFunction { fun, args } => Expr::ScalarFunction { + fun: fun.clone(), + args: args + .iter() + .map(|e| optimize_expr(e, schemas)) + .collect::>()?, + }, + Expr::ScalarUDF { fun, args } => Expr::ScalarUDF { + fun: fun.clone(), + args: args + .iter() + .map(|e| optimize_expr(e, schemas)) + .collect::>()?, + }, + Expr::AggregateFunction { + fun, + args, + distinct, + } => Expr::AggregateFunction { + fun: fun.clone(), + args: args + .iter() + .map(|e| optimize_expr(e, schemas)) + .collect::>()?, + distinct: *distinct, + }, + Expr::AggregateUDF { fun, args } => Expr::AggregateUDF { + fun: fun.clone(), + args: args + .iter() + .map(|e| optimize_expr(e, schemas)) + .collect::>()?, + }, + Expr::Sort { + expr, + asc, + nulls_first, + } => Expr::Sort { + expr: Box::new(optimize_expr(expr, schemas)?), + asc: *asc, + nulls_first: *nulls_first, + }, + Expr::Column { .. } | Expr::ScalarVariable { .. } - | Expr::Between { .. } | Expr::Literal { .. } - | Expr::ScalarFunction { .. } - | Expr::ScalarUDF { .. 
} - | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } - | Expr::Sort { .. } | Expr::Wildcard => e.clone(), }) } @@ -222,7 +304,9 @@ fn optimize_expr(e: &Expr, schema: &DFSchemaRef) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::{col, lit, DFField, DFSchema, LogicalPlanBuilder}; + use crate::logical_plan::{ + col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder, + }; use arrow::datatypes::*; @@ -254,7 +338,7 @@ mod tests { &Expr::Not(Box::new(Expr::Not(Box::new(Expr::Not(Box::new(col( "c2" ))))))), - &schema + &[&schema], )?, col("c2").not(), ); @@ -273,7 +357,7 @@ mod tests { op: Operator::Eq, right: Box::new(Expr::Literal(ScalarValue::Boolean(None))), }, - &schema + &[&schema], )?, Expr::Literal(ScalarValue::Boolean(None)), ); @@ -285,7 +369,7 @@ mod tests { op: Operator::NotEq, right: Box::new(Expr::Literal(ScalarValue::Boolean(None))), }, - &schema + &[&schema], )?, Expr::Literal(ScalarValue::Boolean(None)), ); @@ -297,7 +381,7 @@ mod tests { op: Operator::NotEq, right: Box::new(Expr::Literal(ScalarValue::Boolean(None))), }, - &schema + &[&schema], )?, Expr::Literal(ScalarValue::Boolean(None)), ); @@ -309,7 +393,7 @@ mod tests { op: Operator::Eq, right: Box::new(col("c2")), }, - &schema + &[&schema], )?, Expr::Literal(ScalarValue::Boolean(None)), ); @@ -329,7 +413,7 @@ mod tests { op: Operator::Eq, right: Box::new(lit(true)), }, - &schema + &[&schema], )?, lit(true), ); @@ -341,7 +425,7 @@ mod tests { op: Operator::Eq, right: Box::new(lit(false)), }, - &schema + &[&schema], )?, lit(false), ); @@ -353,7 +437,7 @@ mod tests { op: Operator::Eq, right: Box::new(lit(true)), }, - &schema + &[&schema], )?, col("c2"), ); @@ -365,7 +449,7 @@ mod tests { op: Operator::Eq, right: Box::new(lit(false)), }, - &schema + &[&schema], )?, col("c2").not(), ); @@ -388,7 +472,7 @@ mod tests { op: Operator::Eq, right: Box::new(lit(true)), }, - &schema + &[&schema], )?, Expr::BinaryExpr { left: Box::new(col("c1")), @@ -404,7 +488,7 @@ 
mod tests { op: Operator::Eq, right: Box::new(lit(false)), }, - &schema + &[&schema], )?, Expr::BinaryExpr { left: Box::new(col("c1")), @@ -421,7 +505,7 @@ mod tests { op: Operator::Eq, right: Box::new(lit(true)), }, - &schema + &[&schema], )?, Expr::BinaryExpr { left: Box::new(lit(1)), @@ -437,7 +521,7 @@ mod tests { op: Operator::Eq, right: Box::new(lit(false)), }, - &schema + &[&schema], )?, Expr::BinaryExpr { left: Box::new(lit("a")), @@ -461,7 +545,7 @@ mod tests { op: Operator::NotEq, right: Box::new(lit(true)), }, - &schema + &[&schema], )?, col("c2").not(), ); @@ -473,7 +557,7 @@ mod tests { op: Operator::NotEq, right: Box::new(lit(false)), }, - &schema + &[&schema], )?, col("c2"), ); @@ -486,7 +570,7 @@ mod tests { op: Operator::NotEq, right: Box::new(lit(true)), }, - &schema + &[&schema], )?, lit(false), ); @@ -498,7 +582,7 @@ mod tests { op: Operator::NotEq, right: Box::new(lit(false)), }, - &schema + &[&schema], )?, lit(true), ); @@ -521,7 +605,7 @@ mod tests { op: Operator::NotEq, right: Box::new(lit(true)), }, - &schema + &[&schema], )?, Expr::BinaryExpr { left: Box::new(col("c1")), @@ -537,7 +621,7 @@ mod tests { op: Operator::NotEq, right: Box::new(lit(false)), }, - &schema + &[&schema], )?, Expr::BinaryExpr { left: Box::new(col("c1")), @@ -554,7 +638,7 @@ mod tests { op: Operator::NotEq, right: Box::new(lit(true)), })), - &schema + &[&schema], )?, Expr::Not(Box::new(Expr::BinaryExpr { left: Box::new(lit(1)), @@ -570,7 +654,7 @@ mod tests { op: Operator::NotEq, right: Box::new(lit(false)), })), - &schema + &[&schema], )?, Expr::Not(Box::new(Expr::BinaryExpr { left: Box::new(lit("a")), @@ -600,7 +684,7 @@ mod tests { )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), }), - &schema + &[&schema], )?, Expr::Case { expr: None, @@ -712,4 +796,39 @@ mod tests { assert_optimized_plan_eq(&plan, expected); Ok(()) } + + #[test] + fn optimize_plan_support_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = 
LogicalPlanBuilder::from(&table_scan) + .project(&[col("a"), col("d"), col("b").eq(lit(false))])? + .build()?; + + let expected = "\ + Projection: #a, #d, NOT #b\ + \n TableScan: test projection=None"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + #[test] + fn optimize_plan_support_aggregate() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(&table_scan) + .project(&[col("a"), col("c"), col("b")])? + .aggregate( + &[col("a"), col("c")], + &[max(col("b").eq(lit(true))), min(col("b"))], + )? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[#a, #c]], aggr=[[MAX(#b), MIN(#b)]]\ + \n Projection: #a, #c, #b\ + \n TableScan: test projection=None"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } } From c68794a591324057885d1f7c5b7637284540b9fd Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Mon, 15 Feb 2021 13:22:45 -0800 Subject: [PATCH 09/11] address review feedback --- .../src/optimizer/constant_folding.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index a20017f63e7..aacf7295cbb 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -48,6 +48,12 @@ impl ConstantFolding { impl OptimizerRule for ConstantFolding { fn optimize(&self, plan: &LogicalPlan) -> Result { + // We need to pass down all the schemas within the plan tree to `optimize_expr` in order to + // evaluate expression types. For example, a projection plan's schema will only include + // projected columns. With just the projected schema, it's not possible to infer types for + // expressions that reference non-projected columns within the same projection plan or its + // child plans. 
+ match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { predicate: optimize_expr(predicate, &plan.all_schemas())?, @@ -91,20 +97,15 @@ impl OptimizerRule for ConstantFolding { fn is_boolean_type(expr: &Expr, schemas: &[&DFSchemaRef]) -> bool { for schema in schemas { - match expr.get_type(schema) { - Ok(dt) if dt == DataType::Boolean => { - return true; - } - _ => { - continue; - } + if let Ok(DataType::Boolean) = expr.get_type(schema) { + return true; } } false } -/// Recursively transverses the logical plan. +/// Recursively traverses the expression tree. fn optimize_expr(e: &Expr, schemas: &[&DFSchemaRef]) -> Result { Ok(match e { Expr::BinaryExpr { left, op, right } => { From 8e9605a96136ef463d011bd25fb719a28ec25cc5 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Mon, 15 Feb 2021 13:55:19 -0800 Subject: [PATCH 10/11] simplify tests --- rust/datafusion/src/logical_plan/expr.rs | 6 + .../src/optimizer/constant_folding.rs | 259 ++++-------------- 2 files changed, 55 insertions(+), 210 deletions(-) diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 7f358cb31b0..cfb7250042d 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -770,6 +770,12 @@ impl Literal for String { } } +impl Literal for ScalarValue { + fn lit(&self) -> Expr { + Expr::Literal(self.clone()) + } +} + macro_rules! 
make_literal { ($TYPE:ty, $SCALAR:ident) => { #[allow(missing_docs)] diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index aacf7295cbb..9b06184220c 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -335,12 +335,7 @@ mod tests { fn optimize_expr_not_not() -> Result<()> { let schema = expr_test_schema(); assert_eq!( - optimize_expr( - &Expr::Not(Box::new(Expr::Not(Box::new(Expr::Not(Box::new(col( - "c2" - ))))))), - &[&schema], - )?, + optimize_expr(&col("c2").not().not().not(), &[&schema])?, col("c2").not(), ); @@ -351,52 +346,34 @@ mod tests { fn optimize_expr_null_comparision() -> Result<()> { let schema = expr_test_schema(); + // x = null is always null assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(lit(true)), - op: Operator::Eq, - right: Box::new(Expr::Literal(ScalarValue::Boolean(None))), - }, - &[&schema], - )?, - Expr::Literal(ScalarValue::Boolean(None)), + optimize_expr(&lit(true).eq(lit(ScalarValue::Boolean(None))), &[&schema])?, + lit(ScalarValue::Boolean(None)), ); + // null != null is always null assert_eq!( optimize_expr( - &Expr::BinaryExpr { - left: Box::new(Expr::Literal(ScalarValue::Boolean(None))), - op: Operator::NotEq, - right: Box::new(Expr::Literal(ScalarValue::Boolean(None))), - }, + &lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None))), &[&schema], )?, - Expr::Literal(ScalarValue::Boolean(None)), + lit(ScalarValue::Boolean(None)), ); + // x != null is always null assert_eq!( optimize_expr( - &Expr::BinaryExpr { - left: Box::new(col("c2")), - op: Operator::NotEq, - right: Box::new(Expr::Literal(ScalarValue::Boolean(None))), - }, + &col("c2").not_eq(lit(ScalarValue::Boolean(None))), &[&schema], )?, - Expr::Literal(ScalarValue::Boolean(None)), + lit(ScalarValue::Boolean(None)), ); + // null = x is always null assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: 
Box::new(Expr::Literal(ScalarValue::Boolean(None))), - op: Operator::Eq, - right: Box::new(col("c2")), - }, - &[&schema], - )?, - Expr::Literal(ScalarValue::Boolean(None)), + optimize_expr(&lit(ScalarValue::Boolean(None)).eq(col("c2")), &[&schema])?, + lit(ScalarValue::Boolean(None)), ); Ok(()) @@ -407,51 +384,27 @@ mod tests { let schema = expr_test_schema(); assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); + // true = true -> true assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(lit(true)), - op: Operator::Eq, - right: Box::new(lit(true)), - }, - &[&schema], - )?, + optimize_expr(&lit(true).eq(lit(true)), &[&schema])?, lit(true), ); + // true = false -> false assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(lit(true)), - op: Operator::Eq, - right: Box::new(lit(false)), - }, - &[&schema], - )?, + optimize_expr(&lit(true).eq(lit(false)), &[&schema])?, lit(false), ); + // c2 = true -> c2 assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(col("c2")), - op: Operator::Eq, - right: Box::new(lit(true)), - }, - &[&schema], - )?, + optimize_expr(&col("c2").eq(lit(true)), &[&schema])?, col("c2"), ); + // c2 = false -> !c2 assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(col("c2")), - op: Operator::Eq, - right: Box::new(lit(false)), - }, - &[&schema], - )?, + optimize_expr(&col("c2").eq(lit(false)), &[&schema])?, col("c2").not(), ); @@ -462,73 +415,33 @@ mod tests { fn optimize_expr_eq_skip_nonboolean_type() -> Result<()> { let schema = expr_test_schema(); - // when one of the operand is not of boolean type, folding the other boolean constant will + // When one of the operands is not of boolean type, folding the other boolean constant will // change return type of expression to non-boolean. 
+ // + // Make sure c1 column to be used in tests is not boolean type assert_eq!(col("c1").get_type(&schema)?, DataType::Utf8); + // don't fold c1 = true assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(col("c1")), - op: Operator::Eq, - right: Box::new(lit(true)), - }, - &[&schema], - )?, - Expr::BinaryExpr { - left: Box::new(col("c1")), - op: Operator::Eq, - right: Box::new(lit(true)), - }, + optimize_expr(&col("c1").eq(lit(true)), &[&schema])?, + col("c1").eq(lit(true)), ); + // don't fold c1 = false assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(col("c1")), - op: Operator::Eq, - right: Box::new(lit(false)), - }, - &[&schema], - )?, - Expr::BinaryExpr { - left: Box::new(col("c1")), - op: Operator::Eq, - right: Box::new(lit(false)), - }, + optimize_expr(&col("c1").eq(lit(false)), &[&schema],)?, + col("c1").eq(lit(false)), ); // test constant operands assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(lit(1)), - op: Operator::Eq, - right: Box::new(lit(true)), - }, - &[&schema], - )?, - Expr::BinaryExpr { - left: Box::new(lit(1)), - op: Operator::Eq, - right: Box::new(lit(true)), - }, + optimize_expr(&lit(1).eq(lit(true)), &[&schema],)?, + lit(1).eq(lit(true)), ); assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(lit("a")), - op: Operator::Eq, - right: Box::new(lit(false)), - }, - &[&schema], - )?, - Expr::BinaryExpr { - left: Box::new(lit("a")), - op: Operator::Eq, - right: Box::new(lit(false)), - }, + optimize_expr(&lit("a").eq(lit(false)), &[&schema],)?, + lit("a").eq(lit(false)), ); Ok(()) @@ -539,52 +452,26 @@ mod tests { let schema = expr_test_schema(); assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); + // c2 != true -> !c2 assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(col("c2")), - op: Operator::NotEq, - right: Box::new(lit(true)), - }, - &[&schema], - )?, + optimize_expr(&col("c2").not_eq(lit(true)), &[&schema])?, col("c2").not(), ); + // c2 != false 
-> c2 assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(col("c2")), - op: Operator::NotEq, - right: Box::new(lit(false)), - }, - &[&schema], - )?, + optimize_expr(&col("c2").not_eq(lit(false)), &[&schema])?, col("c2"), ); // test constant assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(lit(true)), - op: Operator::NotEq, - right: Box::new(lit(true)), - }, - &[&schema], - )?, + optimize_expr(&lit(true).not_eq(lit(true)), &[&schema])?, lit(false), ); assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(lit(true)), - op: Operator::NotEq, - right: Box::new(lit(false)), - }, - &[&schema], - )?, + optimize_expr(&lit(true).not_eq(lit(false)), &[&schema])?, lit(true), ); @@ -600,68 +487,24 @@ mod tests { assert_eq!(col("c1").get_type(&schema)?, DataType::Utf8); assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(col("c1")), - op: Operator::NotEq, - right: Box::new(lit(true)), - }, - &[&schema], - )?, - Expr::BinaryExpr { - left: Box::new(col("c1")), - op: Operator::NotEq, - right: Box::new(lit(true)), - }, + optimize_expr(&col("c1").not_eq(lit(true)), &[&schema])?, + col("c1").not_eq(lit(true)), ); assert_eq!( - optimize_expr( - &Expr::BinaryExpr { - left: Box::new(col("c1")), - op: Operator::NotEq, - right: Box::new(lit(false)), - }, - &[&schema], - )?, - Expr::BinaryExpr { - left: Box::new(col("c1")), - op: Operator::NotEq, - right: Box::new(lit(false)), - }, + optimize_expr(&col("c1").not_eq(lit(false)), &[&schema])?, + col("c1").not_eq(lit(false)), ); // test constants assert_eq!( - optimize_expr( - &Expr::Not(Box::new(Expr::BinaryExpr { - left: Box::new(lit(1)), - op: Operator::NotEq, - right: Box::new(lit(true)), - })), - &[&schema], - )?, - Expr::Not(Box::new(Expr::BinaryExpr { - left: Box::new(lit(1)), - op: Operator::NotEq, - right: Box::new(lit(true)), - })), + optimize_expr(&lit(1).not_eq(lit(true)), &[&schema])?, + lit(1).not_eq(lit(true)), ); assert_eq!( - optimize_expr( - 
&Expr::Not(Box::new(Expr::BinaryExpr { - left: Box::new(lit("a")), - op: Operator::NotEq, - right: Box::new(lit(false)), - })), - &[&schema], - )?, - Expr::Not(Box::new(Expr::BinaryExpr { - left: Box::new(lit("a")), - op: Operator::NotEq, - right: Box::new(lit(false)), - })), + optimize_expr(&lit("a").not_eq(lit(false)), &[&schema],)?, + lit("a").not_eq(lit(false)), ); Ok(()) @@ -676,11 +519,7 @@ mod tests { &Box::new(Expr::Case { expr: None, when_then_expr: vec![( - Box::new(Expr::BinaryExpr { - left: Box::new(col("c2")), - op: Operator::NotEq, - right: Box::new(lit(false)), - }), + Box::new(col("c2").not_eq(lit(false))), Box::new(lit("ok").eq(lit(true))), )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), From 874955bc98a986927a8bfc2b68a8935b83108314 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Mon, 15 Feb 2021 14:08:00 -0800 Subject: [PATCH 11/11] optimize case expression when base expression is specified --- .../src/optimizer/constant_folding.rs | 41 +++++++++---------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index 9b06184220c..86cadf6405e 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -196,28 +196,25 @@ fn optimize_expr(e: &Expr, schemas: &[&DFSchemaRef]) -> Result { when_then_expr, else_expr, } => { - if expr.is_none() { - // recurse into CASE WHEN condition expressions - Expr::Case { - expr: None, - when_then_expr: when_then_expr - .iter() - .map(|(when, then)| { - Ok(( - Box::new(optimize_expr(when, schemas)?), - Box::new(optimize_expr(then, schemas)?), - )) - }) - .collect::>()?, - else_expr: match else_expr { - Some(e) => Some(Box::new(optimize_expr(e, schemas)?)), - None => None, - }, - } - } else { - // when base expression is specified, when_then_expr conditions are literal values - // so we can just skip this case - e.clone() + // recurse into 
CASE WHEN condition expressions + Expr::Case { + expr: match expr { + Some(e) => Some(Box::new(optimize_expr(e, schemas)?)), + None => None, + }, + when_then_expr: when_then_expr + .iter() + .map(|(when, then)| { + Ok(( + Box::new(optimize_expr(when, schemas)?), + Box::new(optimize_expr(then, schemas)?), + )) + }) + .collect::>()?, + else_expr: match else_expr { + Some(e) => Some(Box::new(optimize_expr(e, schemas)?)), + None => None, + }, } } Expr::Alias(expr, name) => {