From d2c5e6666814439a049b4cb7a28ae4802b49d164 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 25 Nov 2018 19:00:27 -0800 Subject: [PATCH 1/5] fix --- .../ReplaceNullWithFalseInPredicate.scala | 106 ++++++++++++++++++ .../sql/catalyst/optimizer/expressions.scala | 66 ----------- 2 files changed, 106 insertions(+), 66 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala new file mode 100644 index 000000000000..a525988bcc87 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If} +import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or} +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.BooleanType + + +/** + * A rule that replaces `Literal(null, BooleanType)` with `FalseLiteral`, if possible, in the search + * condition of the WHERE/HAVING/ON(JOIN) clauses, which contain an implicit Boolean operator + * "(search condition) = TRUE". The replacement is only valid when `Literal(null, BooleanType)` is + * semantically equivalent to `FalseLiteral` when evaluating the whole search condition. + * + * Please note that FALSE and NULL are not interchangeable in most cases when the search condition + * contains NOT or NULL-tolerant expressions. Thus, the rule is very conservative and applicable + * in very limited cases. + * + * For example, `Filter(Literal(null, BooleanType))` is equal to `Filter(FalseLiteral)`. + * + * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; + * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually + * `Filter(FalseLiteral)`. + * + * Moreover, this rule also transforms predicates in all [[If]] expressions as well as branch + * conditions in all [[CaseWhen]] expressions, even if they are not part of the search conditions. + * + * For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` can be simplified + * into `Project(Literal(2))`. 
+ */ +object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) + case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case p: LogicalPlan => p transformExpressions { + case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) + case cw @ CaseWhen(branches, _) => + val newBranches = branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> value + } + cw.copy(branches = newBranches) + case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + af.copy(function = newLambda) + case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + ae.copy(function = newLambda) + case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + mf.copy(function = newLambda) + } + } + + /** + * Recursively replaces `Literal(null, BooleanType)` with `FalseLiteral`. + * + * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit + * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or + * `Literal(null, BooleanType)`. 
+ */ + private def replaceNullWithFalse(e: Expression): Expression = { + if (e.dataType != BooleanType) { + e + } else { + e match { + case Literal(null, BooleanType) => + FalseLiteral + case And(left, right) => + And(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case Or(left, right) => + Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case cw: CaseWhen => + val newBranches = cw.branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> replaceNullWithFalse(value) + } + val newElseValue = cw.elseValue.map(replaceNullWithFalse) + CaseWhen(newBranches, newElseValue) + case If(pred, trueVal, falseVal) => + If(replaceNullWithFalse(pred), + replaceNullWithFalse(trueVal), + replaceNullWithFalse(falseVal)) + case _ => e + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 354efd883f81..468a950fb108 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -736,69 +736,3 @@ object CombineConcats extends Rule[LogicalPlan] { flattenConcats(concat) } } - -/** - * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations. - * - * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates - * in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions. - * - * For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`. - * - * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; - * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually - * `Filter(FalseLiteral)`. - * - * As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can - * benefit from it. 
For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` - * can be simplified into `Project(Literal(2))`. - * - * As a result, many unnecessary computations can be removed in the query optimization phase. - */ -object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) - case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) - case p: LogicalPlan => p transformExpressions { - case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) - case cw @ CaseWhen(branches, _) => - val newBranches = branches.map { case (cond, value) => - replaceNullWithFalse(cond) -> value - } - cw.copy(branches = newBranches) - case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => - val newLambda = lf.copy(function = replaceNullWithFalse(func)) - af.copy(function = newLambda) - case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => - val newLambda = lf.copy(function = replaceNullWithFalse(func)) - ae.copy(function = newLambda) - case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => - val newLambda = lf.copy(function = replaceNullWithFalse(func)) - mf.copy(function = newLambda) - } - } - - /** - * Recursively replaces `Literal(null, _)` with `FalseLiteral`. - * - * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit - * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`. 
- */ - private def replaceNullWithFalse(e: Expression): Expression = e match { - case cw: CaseWhen if cw.dataType == BooleanType => - val newBranches = cw.branches.map { case (cond, value) => - replaceNullWithFalse(cond) -> replaceNullWithFalse(value) - } - val newElseValue = cw.elseValue.map(replaceNullWithFalse) - CaseWhen(newBranches, newElseValue) - case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => - If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) - case And(left, right) => - And(replaceNullWithFalse(left), replaceNullWithFalse(right)) - case Or(left, right) => - Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) - case Literal(null, _) => FalseLiteral - case _ => e - } -} From 6b6997d6c5eedb9a75af61345ae808c9d98e6f4d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 25 Nov 2018 19:18:05 -0800 Subject: [PATCH 2/5] fix --- .../catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index a525988bcc87..4357ceab2733 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -72,7 +72,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { } /** - * Recursively replaces `Literal(null, BooleanType)` with `FalseLiteral`. + * Recursively traverse the Boolean-type expression to replace + * `Literal(null, BooleanType)` with `FalseLiteral`, if possible. 
* * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or From e41681096867cbc6d2556da83ce733092d6df841 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 25 Nov 2018 22:26:29 -0800 Subject: [PATCH 3/5] fix the test case --- .../optimizer/ReplaceNullWithFalseInPredicateSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 3a9e6cae0fd8..8541c1031be5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -44,8 +44,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { private val anotherTestRelation = LocalRelation('d.int) test("replace null inside filter and join conditions") { - testFilter(originalCond = Literal(null), expectedCond = FalseLiteral) - testJoin(originalCond = Literal(null), expectedCond = FalseLiteral) + testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) } test("replace null in branches of If") { From 350142028a003e3729cb05f37983e77be548deff Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 26 Nov 2018 10:37:21 -0800 Subject: [PATCH 4/5] issue an exception --- .../ReplaceNullWithFalseInPredicate.scala | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 4357ceab2733..72a60f692ac7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.BooleanType +import org.apache.spark.util.Utils /** @@ -79,29 +80,31 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or * `Literal(null, BooleanType)`. */ - private def replaceNullWithFalse(e: Expression): Expression = { - if (e.dataType != BooleanType) { + private def replaceNullWithFalse(e: Expression): Expression = e match { + case Literal(null, BooleanType) => + FalseLiteral + case And(left, right) => + And(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case Or(left, right) => + Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case cw: CaseWhen if cw.dataType == BooleanType => + val newBranches = cw.branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> replaceNullWithFalse(value) + } + val newElseValue = cw.elseValue.map(replaceNullWithFalse) + CaseWhen(newBranches, newElseValue) + case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => + If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) + case e if e.dataType == BooleanType => e - } else { - e match { - case Literal(null, BooleanType) => - FalseLiteral - case And(left, right) => - And(replaceNullWithFalse(left), replaceNullWithFalse(right)) - case Or(left, right) => - Or(replaceNullWithFalse(left), 
replaceNullWithFalse(right)) - case cw: CaseWhen => - val newBranches = cw.branches.map { case (cond, value) => - replaceNullWithFalse(cond) -> replaceNullWithFalse(value) - } - val newElseValue = cw.elseValue.map(replaceNullWithFalse) - CaseWhen(newBranches, newElseValue) - case If(pred, trueVal, falseVal) => - If(replaceNullWithFalse(pred), - replaceNullWithFalse(trueVal), - replaceNullWithFalse(falseVal)) - case _ => e + case e => + val message = "Expected a Boolean type expression in replaceNullWithFalse, " + + s"but got the type `${e.dataType.catalogString}` in `${e.sql}`." + if (Utils.isTesting) { + throw new IllegalArgumentException(message) + } else { + logWarning(message) + e } - } } } From 8b0401c4440136e33d4580a3f8da80164de3d4b4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 26 Nov 2018 14:41:56 -0800 Subject: [PATCH 5/5] added a test case --- .../optimizer/ReplaceNullWithFalseInPredicateSuite.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 8541c1031be5..ee0d04da3e46 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -48,6 +48,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) } + test("Not expected type - replaceNullWithFalse") { + val e = intercept[IllegalArgumentException] { + testFilter(originalCond = Literal(null, IntegerType), expectedCond = FalseLiteral) + }.getMessage + assert(e.contains("but got the type `int` in `CAST(NULL AS INT)")) + } + test("replace null in branches of If") { val originalCond = If( 
UnresolvedAttribute("i") > Literal(10),