-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-25860][SQL] Replace Literal(null, _) with FalseLiteral whenever possible #22857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -736,3 +736,60 @@ object CombineConcats extends Rule[LogicalPlan] { | |
| flattenConcats(concat) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations. | ||
| * | ||
| * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates | ||
| * in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions. | ||
| * | ||
| * For example, `Filter(Literal(null, _))` is equivalent to `Filter(FalseLiteral)`, because a filter only keeps rows whose condition evaluates to true. | ||
| * | ||
| * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; | ||
| * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually | ||
| * `Filter(FalseLiteral)`. | ||
| * | ||
| * As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can | ||
| * benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` | ||
| * can be simplified into `Project(Literal(2))`. | ||
| * | ||
| * As a result, many unnecessary computations can be removed in the query optimization phase. | ||
| */ | ||
| object ReplaceNullWithFalse extends Rule[LogicalPlan] { | ||
|
|
||
| /** | ||
| * Rewrites null literals to `FalseLiteral` in the conditions of [[Filter]] operators and of | ||
| * [[Join]] operators that have a condition. For the remaining operators, the predicates of | ||
| * [[If]] expressions and the branch conditions of [[CaseWhen]] expressions are rewritten. | ||
| */ | ||
| def apply(plan: LogicalPlan): LogicalPlan = { | ||
| // Only the predicate of an If and the conditions of CaseWhen branches are rewritten here; | ||
| // the values those expressions produce are left untouched. | ||
| val rewriteConditionals: PartialFunction[Expression, Expression] = { | ||
| case ifExpr @ If(predicate, _, _) => | ||
| ifExpr.copy(predicate = replaceNullWithFalse(predicate)) | ||
| case caseWhen @ CaseWhen(branches, _) => | ||
| val rewrittenBranches = branches.map { case (branchCond, branchValue) => | ||
| (replaceNullWithFalse(branchCond), branchValue) | ||
| } | ||
| caseWhen.copy(branches = rewrittenBranches) | ||
| } | ||
| plan transform { | ||
| case filter @ Filter(condition, _) => | ||
| filter.copy(condition = replaceNullWithFalse(condition)) | ||
| case join @ Join(_, _, _, Some(condition)) => | ||
| join.copy(condition = Some(replaceNullWithFalse(condition))) | ||
| case other: LogicalPlan => | ||
| other transformExpressions rewriteConditionals | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Recursively replaces `Literal(null, _)` with `FalseLiteral`. | ||
| * | ||
| * Note that `transformExpressionsDown` cannot be used here as we must stop as soon as we hit | ||
| * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we make it more general? I think the expected expression is:
so I would write something like
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I like your snippet because it is clean. We also considered a similar approach.
Therefore, the intention was to keep things simple to be safe. |
||
| */ | ||
| private def replaceNullWithFalse(e: Expression): Expression = e match { | ||
|
||
| case cw: CaseWhen if cw.dataType == BooleanType => | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This case is also covered and tested in |
||
| val newBranches = cw.branches.map { case (cond, value) => | ||
| replaceNullWithFalse(cond) -> replaceNullWithFalse(value) | ||
| } | ||
| val newElseValue = cw.elseValue.map(replaceNullWithFalse) | ||
| CaseWhen(newBranches, newElseValue) | ||
| case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => | ||
| If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This case is handled in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let me know if I got you correctly here.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The general rule for But in
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, I see. |
||
| case And(left, right) => | ||
|
||
| And(replaceNullWithFalse(left), replaceNullWithFalse(right)) | ||
| case Or(left, right) => | ||
| Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) | ||
| case Literal(null, _) => FalseLiteral | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here, for safety, we should check the data types. |
||
| case _ => e | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,323 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.optimizer | ||
|
|
||
| import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute | ||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or} | ||
| import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} | ||
| import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} | ||
| import org.apache.spark.sql.catalyst.rules.RuleExecutor | ||
| import org.apache.spark.sql.types.{BooleanType, IntegerType} | ||
|
|
||
| class ReplaceNullWithFalseSuite extends PlanTest { | ||
|
|
||
| // Rule executor used by this suite: a single batch that applies the rule under test | ||
| // (ReplaceNullWithFalse) alongside the simplification rules listed before it. | ||
| object Optimize extends RuleExecutor[LogicalPlan] { | ||
| val batches = List( | ||
| Batch("Replace null literals", FixedPoint(10), | ||
| NullPropagation, | ||
| ConstantFolding, | ||
| BooleanSimplification, | ||
| SimplifyConditionals, | ||
| ReplaceNullWithFalse)) | ||
| } | ||
|
|
||
| private val testRelation = LocalRelation('i.int, 'b.boolean) | ||
| private val anotherTestRelation = LocalRelation('d.int) | ||
|
|
||
| test("replace null inside filter and join conditions") { | ||
| testFilter(originalCond = Literal(null), expectedCond = FalseLiteral) | ||
| testJoin(originalCond = Literal(null), expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace null in branches of If") { | ||
| val originalCond = If( | ||
| UnresolvedAttribute("i") > Literal(10), | ||
| FalseLiteral, | ||
| Literal(null, BooleanType)) | ||
| testFilter(originalCond, expectedCond = FalseLiteral) | ||
| testJoin(originalCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace nulls in nested expressions in branches of If") { | ||
| val originalCond = If( | ||
| UnresolvedAttribute("i") > Literal(10), | ||
| TrueLiteral && Literal(null, BooleanType), | ||
| UnresolvedAttribute("b") && Literal(null, BooleanType)) | ||
| testFilter(originalCond, expectedCond = FalseLiteral) | ||
| testJoin(originalCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace null in elseValue of CaseWhen") { | ||
| val branches = Seq( | ||
| (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, | ||
| (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) | ||
| val originalCond = CaseWhen(branches, Literal(null, BooleanType)) | ||
| val expectedCond = CaseWhen(branches, FalseLiteral) | ||
| testFilter(originalCond, expectedCond) | ||
| testJoin(originalCond, expectedCond) | ||
| } | ||
|
|
||
| test("replace null in branch values of CaseWhen") { | ||
| val branches = Seq( | ||
| (UnresolvedAttribute("i") < Literal(10)) -> Literal(null, BooleanType), | ||
| (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) | ||
| val originalCond = CaseWhen(branches, Literal(null)) | ||
| testFilter(originalCond, expectedCond = FalseLiteral) | ||
| testJoin(originalCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace null in branches of If inside CaseWhen") { | ||
| val originalBranches = Seq( | ||
| (UnresolvedAttribute("i") < Literal(10)) -> | ||
| If(UnresolvedAttribute("i") < Literal(20), Literal(null, BooleanType), FalseLiteral), | ||
| (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) | ||
| val originalCond = CaseWhen(originalBranches) | ||
|
|
||
| val expectedBranches = Seq( | ||
| (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, | ||
| (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) | ||
| val expectedCond = CaseWhen(expectedBranches) | ||
|
|
||
| testFilter(originalCond, expectedCond) | ||
| testJoin(originalCond, expectedCond) | ||
| } | ||
|
|
||
| test("replace null in complex CaseWhen expressions") { | ||
| val originalBranches = Seq( | ||
| (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, | ||
| (Literal(6) <= Literal(1)) -> FalseLiteral, | ||
| (Literal(4) === Literal(5)) -> FalseLiteral, | ||
| (UnresolvedAttribute("i") > Literal(10)) -> Literal(null, BooleanType), | ||
| (Literal(4) === Literal(4)) -> TrueLiteral) | ||
| val originalCond = CaseWhen(originalBranches) | ||
|
|
||
| val expectedBranches = Seq( | ||
| (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, | ||
| (UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral, | ||
| TrueLiteral -> TrueLiteral) | ||
| val expectedCond = CaseWhen(expectedBranches) | ||
|
|
||
| testFilter(originalCond, expectedCond) | ||
| testJoin(originalCond, expectedCond) | ||
| } | ||
|
|
||
| test("replace null in Or") { | ||
| val originalCond = Or(UnresolvedAttribute("b"), Literal(null)) | ||
| val expectedCond = UnresolvedAttribute("b") | ||
| testFilter(originalCond, expectedCond) | ||
| testJoin(originalCond, expectedCond) | ||
| } | ||
|
|
||
| test("replace null in And") { | ||
| val originalCond = And(UnresolvedAttribute("b"), Literal(null)) | ||
| testFilter(originalCond, expectedCond = FalseLiteral) | ||
| testJoin(originalCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace nulls in nested And/Or expressions") { | ||
| val originalCond = And( | ||
| And(UnresolvedAttribute("b"), Literal(null)), | ||
| Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"), Literal(null))))) | ||
| testFilter(originalCond, expectedCond = FalseLiteral) | ||
| testJoin(originalCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace null in And inside branches of If") { | ||
| val originalCond = If( | ||
| UnresolvedAttribute("i") > Literal(10), | ||
| FalseLiteral, | ||
| And(UnresolvedAttribute("b"), Literal(null, BooleanType))) | ||
| testFilter(originalCond, expectedCond = FalseLiteral) | ||
| testJoin(originalCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace null in branches of If inside And") { | ||
| val originalCond = And( | ||
| UnresolvedAttribute("b"), | ||
| If( | ||
| UnresolvedAttribute("i") > Literal(10), | ||
| Literal(null), | ||
| And(FalseLiteral, UnresolvedAttribute("b")))) | ||
| testFilter(originalCond, expectedCond = FalseLiteral) | ||
| testJoin(originalCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace null in branches of If inside another If") { | ||
| val originalCond = If( | ||
| If(UnresolvedAttribute("b"), Literal(null), FalseLiteral), | ||
| TrueLiteral, | ||
| Literal(null)) | ||
| testFilter(originalCond, expectedCond = FalseLiteral) | ||
| testJoin(originalCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace null in CaseWhen inside another CaseWhen") { | ||
| val nestedCaseWhen = CaseWhen(Seq(UnresolvedAttribute("b") -> FalseLiteral), Literal(null)) | ||
| val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral), Literal(null)) | ||
| testFilter(originalCond, expectedCond = FalseLiteral) | ||
| testJoin(originalCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("inability to replace null in non-boolean branches of If") { | ||
| val condition = If( | ||
| UnresolvedAttribute("i") > Literal(10), | ||
| Literal(5) > If( | ||
| UnresolvedAttribute("i") === Literal(15), | ||
| Literal(null, IntegerType), | ||
| Literal(3)), | ||
| FalseLiteral) | ||
| testFilter(originalCond = condition, expectedCond = condition) | ||
| testJoin(originalCond = condition, expectedCond = condition) | ||
| } | ||
|
|
||
| test("inability to replace null in non-boolean values of CaseWhen") { | ||
| val nestedCaseWhen = CaseWhen( | ||
| Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2)), | ||
| Literal(null, IntegerType)) | ||
| val branchValue = If( | ||
| Literal(2) === nestedCaseWhen, | ||
| TrueLiteral, | ||
| FalseLiteral) | ||
| val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) | ||
| val condition = CaseWhen(branches) | ||
| testFilter(originalCond = condition, expectedCond = condition) | ||
| testJoin(originalCond = condition, expectedCond = condition) | ||
| } | ||
|
|
||
| test("inability to replace null in non-boolean branches of If inside another If") { | ||
| val condition = If( | ||
| Literal(5) > If( | ||
| UnresolvedAttribute("i") === Literal(15), | ||
| Literal(null, IntegerType), | ||
| Literal(3)), | ||
| TrueLiteral, | ||
| FalseLiteral) | ||
| testFilter(originalCond = condition, expectedCond = condition) | ||
| testJoin(originalCond = condition, expectedCond = condition) | ||
| } | ||
|
|
||
| test("replace null in If used as a join condition") { | ||
| // this test is only for joins as the condition involves columns from different relations | ||
| val originalCond = If( | ||
| UnresolvedAttribute("d") > UnresolvedAttribute("i"), | ||
| Literal(null), | ||
| FalseLiteral) | ||
| testJoin(originalCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace null in CaseWhen used as a join condition") { | ||
| // this test is only for joins as the condition involves columns from different relations | ||
| val originalBranches = Seq( | ||
| (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null), | ||
| (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) | ||
|
|
||
| val expectedBranches = Seq( | ||
| (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> FalseLiteral, | ||
| (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) | ||
|
|
||
| testJoin( | ||
| originalCond = CaseWhen(originalBranches, FalseLiteral), | ||
| expectedCond = CaseWhen(expectedBranches, FalseLiteral)) | ||
| } | ||
|
|
||
| test("inability to replace null in CaseWhen inside EqualTo used as a join condition") { | ||
| // this test is only for joins as the condition involves columns from different relations | ||
| val branches = Seq( | ||
| (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null, BooleanType), | ||
| (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) | ||
| val condition = UnresolvedAttribute("b") === CaseWhen(branches, FalseLiteral) | ||
| testJoin(originalCond = condition, expectedCond = condition) | ||
| } | ||
|
|
||
| test("replace null in predicates of If") { | ||
| val predicate = And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) | ||
| testProjection( | ||
| originalExpr = If(predicate, Literal(5), Literal(1)).as("out"), | ||
| expectedExpr = Literal(1).as("out")) | ||
| } | ||
|
|
||
| test("replace null in predicates of If inside another If") { | ||
| val predicate = If( | ||
| And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)), | ||
| TrueLiteral, | ||
| FalseLiteral) | ||
| testProjection( | ||
| originalExpr = If(predicate, Literal(5), Literal(1)).as("out"), | ||
| expectedExpr = Literal(1).as("out")) | ||
| } | ||
|
|
||
| test("inability to replace null in non-boolean expressions inside If predicates") { | ||
| val predicate = GreaterThan( | ||
| UnresolvedAttribute("i"), | ||
| If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) | ||
| val column = If(predicate, Literal(5), Literal(1)).as("out") | ||
| testProjection(originalExpr = column, expectedExpr = column) | ||
| } | ||
|
|
||
| test("replace null in conditions of CaseWhen") { | ||
| val branches = Seq( | ||
| And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) -> Literal(5)) | ||
| testProjection( | ||
| originalExpr = CaseWhen(branches, Literal(2)).as("out"), | ||
| expectedExpr = Literal(2).as("out")) | ||
| } | ||
|
|
||
| test("replace null in conditions of CaseWhen inside another CaseWhen") { | ||
| val nestedCaseWhen = CaseWhen( | ||
| Seq(And(UnresolvedAttribute("b"), Literal(null)) -> Literal(5)), | ||
| Literal(2)) | ||
| val branches = Seq(GreaterThan(Literal(3), nestedCaseWhen) -> Literal(1)) | ||
| testProjection( | ||
| originalExpr = CaseWhen(branches).as("out"), | ||
| expectedExpr = Literal(1).as("out")) | ||
| } | ||
|
|
||
| test("inability to replace null in non-boolean exprs inside CaseWhen conditions") { | ||
| val condition = GreaterThan( | ||
| UnresolvedAttribute("i"), | ||
| If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) | ||
| val column = CaseWhen(Seq(condition -> Literal(5)), Literal(2)).as("out") | ||
| testProjection(originalExpr = column, expectedExpr = column) | ||
| } | ||
|
|
||
| /** | ||
| * Checks that filtering the test relation by `originalCond` is optimized into filtering it | ||
| * by `expectedCond`. | ||
| */ | ||
| private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { | ||
| val withFilter: (LogicalPlan, Expression) => LogicalPlan = | ||
| (relation, predicate) => relation.where(predicate) | ||
| test(withFilter, originalCond, expectedCond) | ||
| } | ||
|
|
||
| /** | ||
| * Checks that an inner join of the test relation with `anotherTestRelation` on `originalCond` | ||
| * is optimized into the same join on `expectedCond`. | ||
| */ | ||
| private def testJoin(originalCond: Expression, expectedCond: Expression): Unit = { | ||
| val withJoin: (LogicalPlan, Expression) => LogicalPlan = | ||
| (relation, joinCond) => relation.join(anotherTestRelation, Inner, Some(joinCond)) | ||
| test(withJoin, originalCond, expectedCond) | ||
| } | ||
|
|
||
| /** | ||
| * Checks that selecting `originalExpr` from the test relation is optimized into selecting | ||
| * `expectedExpr`. | ||
| */ | ||
| private def testProjection(originalExpr: Expression, expectedExpr: Expression): Unit = { | ||
| val withProjection: (LogicalPlan, Expression) => LogicalPlan = | ||
| (relation, column) => relation.select(column) | ||
| test(withProjection, originalExpr, expectedExpr) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a plan from `testRelation` and `originalExpr` with `func`, analyzes and optimizes | ||
| * it, and compares the result with the analyzed plan built from `expectedExpr`. | ||
| */ | ||
| private def test( | ||
| func: (LogicalPlan, Expression) => LogicalPlan, | ||
| originalExpr: Expression, | ||
| expectedExpr: Expression): Unit = { | ||
| // Both plans are built and analyzed the same way; only the original one is optimized. | ||
| def analyzedPlanFor(expr: Expression): LogicalPlan = func(testRelation, expr).analyze | ||
| val actualPlan = Optimize.execute(analyzedPlanFor(originalExpr)) | ||
| comparePlans(actualPlan, analyzedPlanFor(expectedExpr)) | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let us move it to a new file. The file is growing too big.