diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 6dadefea548..775ab64ac14 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -563,6 +563,179 @@ impl Expr { visitor.post_visit(self) } + + /// Performs a depth first walk of an expression and its children + /// to rewrite an expression, consuming `self` producing a new + /// [`Expr`]. + /// + /// Implements a modified version of the [visitor + /// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to + /// separate algorithms from the structure of the `Expr` tree and + /// make it easier to write new, efficient expression + /// transformation algorithms. + /// + /// For an expression tree such as + /// ```text + /// BinaryExpr (GT) + /// left: Column("foo") + /// right: Column("bar") + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(BinaryExpr(GT)) + /// pre_visit(Column("foo")) + /// mutate(Column("foo")) + /// pre_visit(Column("bar")) + /// mutate(Column("bar")) + /// mutate(BinaryExpr(GT)) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`false`] is returned on a call to pre_visit, no + /// children of that expression are visited, nor is mutate + /// called on that expression + /// + pub fn rewrite(self, rewriter: &mut R) -> Result + where + R: ExprRewriter, + { + if !rewriter.pre_visit(&self)? 
{ + return Ok(self); + }; + + // recurse into all sub expressions(and cover all expression types) + let expr = match self { + Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), + Expr::Column(name) => Expr::Column(name), + Expr::ScalarVariable(names) => Expr::ScalarVariable(names), + Expr::Literal(value) => Expr::Literal(value), + Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { + left: rewrite_boxed(left, rewriter)?, + op, + right: rewrite_boxed(right, rewriter)?, + }, + Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?), + Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?), + Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?), + Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?), + Expr::Between { + expr, + low, + high, + negated, + } => Expr::Between { + expr: rewrite_boxed(expr, rewriter)?, + low: rewrite_boxed(low, rewriter)?, + high: rewrite_boxed(high, rewriter)?, + negated, + }, + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let expr = rewrite_option_box(expr, rewriter)?; + let when_then_expr = when_then_expr + .into_iter() + .map(|(when, then)| { + Ok(( + rewrite_boxed(when, rewriter)?, + rewrite_boxed(then, rewriter)?, + )) + }) + .collect::>>()?; + + let else_expr = rewrite_option_box(else_expr, rewriter)?; + + Expr::Case { + expr, + when_then_expr, + else_expr, + } + } + Expr::Cast { expr, data_type } => Expr::Cast { + expr: rewrite_boxed(expr, rewriter)?, + data_type, + }, + Expr::Sort { + expr, + asc, + nulls_first, + } => Expr::Sort { + expr: rewrite_boxed(expr, rewriter)?, + asc, + nulls_first, + }, + Expr::ScalarFunction { args, fun } => Expr::ScalarFunction { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::ScalarUDF { args, fun } => Expr::ScalarUDF { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::AggregateFunction { + args, + fun, + distinct, + } => Expr::AggregateFunction { + args: rewrite_vec(args, 
rewriter)?, + fun, + distinct, + }, + Expr::AggregateUDF { args, fun } => Expr::AggregateUDF { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::InList { + expr, + list, + negated, + } => Expr::InList { + expr: rewrite_boxed(expr, rewriter)?, + list, + negated, + }, + Expr::Wildcard => Expr::Wildcard, + }; + + // now rewrite this expression itself + rewriter.mutate(expr) + } +} + +#[allow(clippy::boxed_local)] +fn rewrite_boxed(boxed_expr: Box, rewriter: &mut R) -> Result> +where + R: ExprRewriter, +{ + // TODO: It might be possible to avoid an allocation (the + // Box::new) below by reusing the box. + let expr: Expr = *boxed_expr; + let rewritten_expr = expr.rewrite(rewriter)?; + Ok(Box::new(rewritten_expr)) +} + +fn rewrite_option_box( + option_box: Option>, + rewriter: &mut R, +) -> Result>> +where + R: ExprRewriter, +{ + option_box + .map(|expr| rewrite_boxed(expr, rewriter)) + .transpose() +} + +/// rewrite a `Vec` of `Expr`s with the rewriter +fn rewrite_vec(v: Vec, rewriter: &mut R) -> Result> +where + R: ExprRewriter, +{ + v.into_iter().map(|expr| expr.rewrite(rewriter)).collect() } /// Controls how the visitor recursion should proceed. @@ -589,6 +762,22 @@ pub trait ExpressionVisitor: Sized { } } +/// Trait for potentially recursively rewriting an [`Expr`] expression +/// tree. When passed to `Expr::rewrite`, `ExprRewriter::mutate` is +/// invoked recursively on all nodes of an expression tree. See the +/// comments on `Expr::rewrite` for details on its use +pub trait ExprRewriter: Sized { + /// Invoked before any children of `expr` are rewritten / + /// visited. Default implementation returns `Ok(true)` + fn pre_visit(&mut self, _expr: &Expr) -> Result { + Ok(true) + } + + /// Invoked after all children of `expr` have been mutated and + /// returns a potentially modified expr. 
+ fn mutate(&mut self, expr: Expr) -> Result; +} + pub struct CaseBuilder { expr: Option>, when_expr: Vec, @@ -1180,4 +1369,74 @@ mod tests { .end(); assert!(maybe_expr.is_err()); } + + #[test] + fn rewriter_visit() { + let mut rewriter = RecordingRewriter::default(); + col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); + + assert_eq!( + rewriter.v, + vec![ + "Previsited #state Eq Utf8(\"CO\")", + "Previsited #state", + "Mutated #state", + "Previsited Utf8(\"CO\")", + "Mutated Utf8(\"CO\")", + "Mutated #state Eq Utf8(\"CO\")" + ] + ) + } + + #[derive(Default)] + struct RecordingRewriter { + v: Vec, + } + impl ExprRewriter for RecordingRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + self.v.push(format!("Mutated {:?}", expr)); + Ok(expr) + } + + fn pre_visit(&mut self, expr: &Expr) -> Result { + self.v.push(format!("Previsited {:?}", expr)); + Ok(true) + } + } + + #[test] + fn rewriter_rewrite() { + let mut rewriter = FooBarRewriter {}; + + // rewrites "foo" --> "bar" + let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap(); + assert_eq!(rewritten, col("state").eq(lit("bar"))); + + // doesn't rewrite + let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap(); + assert_eq!(rewritten, col("state").eq(lit("baz"))); + } + + /// rewrites all "foo" string literals to "bar" + struct FooBarRewriter {} + impl ExprRewriter for FooBarRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::Literal(scalar) => { + if let ScalarValue::Utf8(Some(utf8_val)) = scalar { + let utf8_val = if utf8_val == "foo" { + "bar".to_string() + } else { + utf8_val + }; + Ok(lit(utf8_val)) + } else { + Ok(Expr::Literal(scalar)) + } + } + // otherwise, return the expression unchanged + expr => Ok(expr), + } + } + } } diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 99c35fafd54..90c35dc3a23 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ 
b/rust/datafusion/src/logical_plan/mod.rs @@ -38,7 +38,7 @@ pub use expr::{ count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln, log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, substr, sum, tan, trim, - trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion, + trunc, upper, when, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index 86cadf6405e..62f5ee30c62 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use arrow::datatypes::DataType; use crate::error::Result; -use crate::logical_plan::{DFSchemaRef, Expr, LogicalPlan, Operator}; +use crate::logical_plan::{DFSchemaRef, Expr, ExprRewriter, LogicalPlan, Operator}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::scalar::ScalarValue; @@ -53,10 +53,13 @@ impl OptimizerRule for ConstantFolding { // projected columns. With just the projected schema, it's not possible to infer types for // expressions that references non-projected columns within the same project plan or its // children plans. 
+ let mut rewriter = ConstantRewriter { + schemas: plan.all_schemas(), + }; match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: optimize_expr(predicate, &plan.all_schemas())?, + predicate: predicate.clone().rewrite(&mut rewriter)?, input: Arc::new(self.optimize(input)?), }), // Rest: recurse into plan, apply optimization where possible @@ -76,10 +79,9 @@ impl OptimizerRule for ConstantFolding { .map(|plan| self.optimize(plan)) .collect::>>()?; - let schemas = plan.all_schemas(); let expr = utils::expressions(plan) - .iter() - .map(|e| optimize_expr(e, &schemas)) + .into_iter() + .map(|e| e.rewrite(&mut rewriter)) .collect::>>()?; utils::from_plan(plan, &expr, &new_inputs) @@ -95,24 +97,29 @@ impl OptimizerRule for ConstantFolding { } } -fn is_boolean_type(expr: &Expr, schemas: &[&DFSchemaRef]) -> bool { - for schema in schemas { - if let Ok(DataType::Boolean) = expr.get_type(schema) { - return true; +struct ConstantRewriter<'a> { + /// input schemas + schemas: Vec<&'a DFSchemaRef>, +} + +impl<'a> ConstantRewriter<'a> { + fn is_boolean_type(&self, expr: &Expr) -> bool { + for schema in &self.schemas { + if let Ok(DataType::Boolean) = expr.get_type(schema) { + return true; + } } - } - false + false + } } -/// Recursively transverses the expression tree. 
-fn optimize_expr(e: &Expr, schemas: &[&DFSchemaRef]) -> Result { - Ok(match e { - Expr::BinaryExpr { left, op, right } => { - let left = optimize_expr(left, schemas)?; - let right = optimize_expr(right, schemas)?; - match op { - Operator::Eq => match (&left, &right) { +impl<'a> ExprRewriter for ConstantRewriter<'a> { + /// rewrite the expression simplifying any constant expressions + fn mutate(&mut self, expr: Expr) -> Result { + let new_expr = match expr { + Expr::BinaryExpr { left, op, right } => match op { + Operator::Eq => match (left.as_ref(), right.as_ref()) { ( Expr::Literal(ScalarValue::Boolean(l)), Expr::Literal(ScalarValue::Boolean(r)), @@ -123,30 +130,30 @@ fn optimize_expr(e: &Expr, schemas: &[&DFSchemaRef]) -> Result { _ => Expr::Literal(ScalarValue::Boolean(None)), }, (Expr::Literal(ScalarValue::Boolean(b)), _) - if is_boolean_type(&right, schemas) => + if self.is_boolean_type(&right) => { match b { - Some(true) => right, - Some(false) => Expr::Not(Box::new(right)), + Some(true) => *right, + Some(false) => Expr::Not(right), None => Expr::Literal(ScalarValue::Boolean(None)), } } (_, Expr::Literal(ScalarValue::Boolean(b))) - if is_boolean_type(&left, schemas) => + if self.is_boolean_type(&left) => { match b { - Some(true) => left, - Some(false) => Expr::Not(Box::new(left)), + Some(true) => *left, + Some(false) => Expr::Not(left), None => Expr::Literal(ScalarValue::Boolean(None)), } } _ => Expr::BinaryExpr { - left: Box::new(left), + left, op: Operator::Eq, - right: Box::new(right), + right, }, }, - Operator::NotEq => match (&left, &right) { + Operator::NotEq => match (left.as_ref(), right.as_ref()) { ( Expr::Literal(ScalarValue::Boolean(l)), Expr::Literal(ScalarValue::Boolean(r)), @@ -157,146 +164,46 @@ fn optimize_expr(e: &Expr, schemas: &[&DFSchemaRef]) -> Result { _ => Expr::Literal(ScalarValue::Boolean(None)), }, (Expr::Literal(ScalarValue::Boolean(b)), _) - if is_boolean_type(&right, schemas) => + if self.is_boolean_type(&right) => { match b { - 
Some(true) => Expr::Not(Box::new(right)), - Some(false) => right, + Some(true) => Expr::Not(right), + Some(false) => *right, None => Expr::Literal(ScalarValue::Boolean(None)), } } (_, Expr::Literal(ScalarValue::Boolean(b))) - if is_boolean_type(&left, schemas) => + if self.is_boolean_type(&left) => { match b { - Some(true) => Expr::Not(Box::new(left)), - Some(false) => left, + Some(true) => Expr::Not(left), + Some(false) => *left, None => Expr::Literal(ScalarValue::Boolean(None)), } } _ => Expr::BinaryExpr { - left: Box::new(left), + left, op: Operator::NotEq, - right: Box::new(right), + right, }, }, - _ => Expr::BinaryExpr { - left: Box::new(left), - op: *op, - right: Box::new(right), - }, + _ => Expr::BinaryExpr { left, op, right }, + }, + Expr::Not(inner) => { + // Not(Not(expr)) --> expr + if let Expr::Not(negated_inner) = *inner { + *negated_inner + } else { + Expr::Not(inner) + } } - } - Expr::Not(expr) => match &**expr { - Expr::Not(inner) => optimize_expr(&inner, schemas)?, - _ => Expr::Not(Box::new(optimize_expr(&expr, schemas)?)), - }, - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - // recurse into CASE WHEN condition expressions - Expr::Case { - expr: match expr { - Some(e) => Some(Box::new(optimize_expr(e, schemas)?)), - None => None, - }, - when_then_expr: when_then_expr - .iter() - .map(|(when, then)| { - Ok(( - Box::new(optimize_expr(when, schemas)?), - Box::new(optimize_expr(then, schemas)?), - )) - }) - .collect::>()?, - else_expr: match else_expr { - Some(e) => Some(Box::new(optimize_expr(e, schemas)?)), - None => None, - }, + expr => { + // no rewrite possible + expr } - } - Expr::Alias(expr, name) => { - Expr::Alias(Box::new(optimize_expr(expr, schemas)?), name.clone()) - } - Expr::Negative(expr) => Expr::Negative(Box::new(optimize_expr(expr, schemas)?)), - Expr::InList { - expr, - list, - negated, - } => Expr::InList { - expr: Box::new(optimize_expr(expr, schemas)?), - list: list - .iter() - .map(|e| optimize_expr(e, schemas)) 
- .collect::>()?, - negated: *negated, - }, - Expr::IsNotNull(expr) => Expr::IsNotNull(Box::new(optimize_expr(expr, schemas)?)), - Expr::IsNull(expr) => Expr::IsNull(Box::new(optimize_expr(expr, schemas)?)), - Expr::Cast { expr, data_type } => Expr::Cast { - expr: Box::new(optimize_expr(expr, schemas)?), - data_type: data_type.clone(), - }, - Expr::Between { - expr, - negated, - low, - high, - } => Expr::Between { - expr: Box::new(optimize_expr(expr, schemas)?), - negated: *negated, - low: Box::new(optimize_expr(low, schemas)?), - high: Box::new(optimize_expr(high, schemas)?), - }, - Expr::ScalarFunction { fun, args } => Expr::ScalarFunction { - fun: fun.clone(), - args: args - .iter() - .map(|e| optimize_expr(e, schemas)) - .collect::>()?, - }, - Expr::ScalarUDF { fun, args } => Expr::ScalarUDF { - fun: fun.clone(), - args: args - .iter() - .map(|e| optimize_expr(e, schemas)) - .collect::>()?, - }, - Expr::AggregateFunction { - fun, - args, - distinct, - } => Expr::AggregateFunction { - fun: fun.clone(), - args: args - .iter() - .map(|e| optimize_expr(e, schemas)) - .collect::>()?, - distinct: *distinct, - }, - Expr::AggregateUDF { fun, args } => Expr::AggregateUDF { - fun: fun.clone(), - args: args - .iter() - .map(|e| optimize_expr(e, schemas)) - .collect::>()?, - }, - Expr::Sort { - expr, - asc, - nulls_first, - } => Expr::Sort { - expr: Box::new(optimize_expr(expr, schemas)?), - asc: *asc, - nulls_first: *nulls_first, - }, - Expr::Column { .. } - | Expr::ScalarVariable { .. } - | Expr::Literal { .. 
} - | Expr::Wildcard => e.clone(), - }) + }; + Ok(new_expr) + } } #[cfg(test)] @@ -331,8 +238,12 @@ mod tests { #[test] fn optimize_expr_not_not() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; + assert_eq!( - optimize_expr(&col("c2").not().not().not(), &[&schema])?, + (col("c2").not().not().not()).rewrite(&mut rewriter)?, col("c2").not(), ); @@ -342,34 +253,32 @@ mod tests { #[test] fn optimize_expr_null_comparision() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; // x = null is always null assert_eq!( - optimize_expr(&lit(true).eq(lit(ScalarValue::Boolean(None))), &[&schema])?, + (lit(true).eq(lit(ScalarValue::Boolean(None)))).rewrite(&mut rewriter)?, lit(ScalarValue::Boolean(None)), ); // null != null is always null assert_eq!( - optimize_expr( - &lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None))), - &[&schema], - )?, + (lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))) + .rewrite(&mut rewriter)?, lit(ScalarValue::Boolean(None)), ); // x != null is always null assert_eq!( - optimize_expr( - &col("c2").not_eq(lit(ScalarValue::Boolean(None))), - &[&schema], - )?, + (col("c2").not_eq(lit(ScalarValue::Boolean(None)))).rewrite(&mut rewriter)?, lit(ScalarValue::Boolean(None)), ); // null = x is always null assert_eq!( - optimize_expr(&lit(ScalarValue::Boolean(None)).eq(col("c2")), &[&schema])?, + (lit(ScalarValue::Boolean(None)).eq(col("c2"))).rewrite(&mut rewriter)?, lit(ScalarValue::Boolean(None)), ); @@ -379,29 +288,27 @@ mod tests { #[test] fn optimize_expr_eq() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; + assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); // true = ture -> true - assert_eq!( - optimize_expr(&lit(true).eq(lit(true)), &[&schema])?, - lit(true), - ); + 
assert_eq!((lit(true).eq(lit(true))).rewrite(&mut rewriter)?, lit(true),); // true = false -> false assert_eq!( - optimize_expr(&lit(true).eq(lit(false)), &[&schema])?, + (lit(true).eq(lit(false))).rewrite(&mut rewriter)?, lit(false), ); // c2 = true -> c2 - assert_eq!( - optimize_expr(&col("c2").eq(lit(true)), &[&schema])?, - col("c2"), - ); + assert_eq!((col("c2").eq(lit(true))).rewrite(&mut rewriter)?, col("c2"),); // c2 = false => !c2 assert_eq!( - optimize_expr(&col("c2").eq(lit(false)), &[&schema])?, + (col("c2").eq(lit(false))).rewrite(&mut rewriter)?, col("c2").not(), ); @@ -411,6 +318,9 @@ mod tests { #[test] fn optimize_expr_eq_skip_nonboolean_type() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; // When one of the operand is not of boolean type, folding the other boolean constant will // change return type of expression to non-boolean. @@ -420,24 +330,24 @@ mod tests { // don't fold c1 = true assert_eq!( - optimize_expr(&col("c1").eq(lit(true)), &[&schema])?, + (col("c1").eq(lit(true))).rewrite(&mut rewriter)?, col("c1").eq(lit(true)), ); // don't fold c1 = false assert_eq!( - optimize_expr(&col("c1").eq(lit(false)), &[&schema],)?, + (col("c1").eq(lit(false))).rewrite(&mut rewriter)?, col("c1").eq(lit(false)), ); // test constant operands assert_eq!( - optimize_expr(&lit(1).eq(lit(true)), &[&schema],)?, + (lit(1).eq(lit(true))).rewrite(&mut rewriter)?, lit(1).eq(lit(true)), ); assert_eq!( - optimize_expr(&lit("a").eq(lit(false)), &[&schema],)?, + (lit("a").eq(lit(false))).rewrite(&mut rewriter)?, lit("a").eq(lit(false)), ); @@ -447,28 +357,32 @@ mod tests { #[test] fn optimize_expr_not_eq() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; + assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); // c2 != true -> !c2 assert_eq!( - optimize_expr(&col("c2").not_eq(lit(true)), &[&schema])?, + 
(col("c2").not_eq(lit(true))).rewrite(&mut rewriter)?, col("c2").not(), ); // c2 != false -> c2 assert_eq!( - optimize_expr(&col("c2").not_eq(lit(false)), &[&schema])?, + (col("c2").not_eq(lit(false))).rewrite(&mut rewriter)?, col("c2"), ); // test constant assert_eq!( - optimize_expr(&lit(true).not_eq(lit(true)), &[&schema])?, + (lit(true).not_eq(lit(true))).rewrite(&mut rewriter)?, lit(false), ); assert_eq!( - optimize_expr(&lit(true).not_eq(lit(false)), &[&schema])?, + (lit(true).not_eq(lit(false))).rewrite(&mut rewriter)?, lit(true), ); @@ -478,29 +392,32 @@ mod tests { #[test] fn optimize_expr_not_eq_skip_nonboolean_type() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; // when one of the operand is not of boolean type, folding the other boolean constant will // change return type of expression to non-boolean. assert_eq!(col("c1").get_type(&schema)?, DataType::Utf8); assert_eq!( - optimize_expr(&col("c1").not_eq(lit(true)), &[&schema])?, + (col("c1").not_eq(lit(true))).rewrite(&mut rewriter)?, col("c1").not_eq(lit(true)), ); assert_eq!( - optimize_expr(&col("c1").not_eq(lit(false)), &[&schema])?, + (col("c1").not_eq(lit(false))).rewrite(&mut rewriter)?, col("c1").not_eq(lit(false)), ); // test constants assert_eq!( - optimize_expr(&lit(1).not_eq(lit(true)), &[&schema])?, + (lit(1).not_eq(lit(true))).rewrite(&mut rewriter)?, lit(1).not_eq(lit(true)), ); assert_eq!( - optimize_expr(&lit("a").not_eq(lit(false)), &[&schema],)?, + (lit("a").not_eq(lit(false))).rewrite(&mut rewriter)?, lit("a").not_eq(lit(false)), ); @@ -510,19 +427,20 @@ mod tests { #[test] fn optimize_expr_case_when_then_else() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; assert_eq!( - optimize_expr( - &Box::new(Expr::Case { - expr: None, - when_then_expr: vec![( - Box::new(col("c2").not_eq(lit(false))), - Box::new(lit("ok").eq(lit(true))), - )], - 
else_expr: Some(Box::new(col("c2").eq(lit(true)))), - }), - &[&schema], - )?, + (Box::new(Expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(col("c2").not_eq(lit(false))), + Box::new(lit("ok").eq(lit(true))), + )], + else_expr: Some(Box::new(col("c2").eq(lit(true)))), + })) + .rewrite(&mut rewriter)?, Expr::Case { expr: None, when_then_expr: vec![( @@ -627,7 +545,7 @@ mod tests { let expected = "\ Projection: #a\ - \n Filter: NOT NOT #b\ + \n Filter: #b\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected);