Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
SimplifyConditionals,
RemoveDispensableExpressions,
SimplifyBinaryComparison,
ReplaceNullWithFalse,
PruneFilters,
EliminateSorts,
SimplifyCasts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -736,3 +736,60 @@ object CombineConcats extends Rule[LogicalPlan] {
flattenConcats(concat)
}
}

/**
* A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations.
*
* This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates
* in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions.
*
* For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`.
*
* Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`;
* this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually
* `Filter(FalseLiteral)`.
*
* As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can
* benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))`
* can be simplified into `Project(Literal(2))`.
*
* As a result, many unnecessary computations can be removed in the query optimization phase.
*/
object ReplaceNullWithFalse extends Rule[LogicalPlan] {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us move it to a new file. The file is growing too big.


def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond)))
case p: LogicalPlan => p transformExpressions {
case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
case cw @ CaseWhen(branches, _) =>
val newBranches = branches.map { case (cond, value) =>
replaceNullWithFalse(cond) -> value
}
cw.copy(branches = newBranches)
}
}

/**
* Recursively replaces `Literal(null, _)` with `FalseLiteral`.
*
* Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit
* an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`.
Copy link
Contributor

@cloud-fan cloud-fan Oct 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make it more general? I think the expected expression is:

  1. It's NullIntolerant. If any child is null, it will be null.
  2. it has a null child.

so I would write something like

def replaceNullWithFalse(e: Expression): Expression = e match {
  case _ if alwaysNull(e) => FalseLiteral
  case And...
  case Or...
  case _ => e
}

def alwaysNull(e: Expression): Boolean = e match {
  case Literal(null, _) => true
  case n: NullIntolerant => n.children.exists(alwaysNull) 
  case _ => false
}

Copy link
Contributor Author

@aokolnychyi aokolnychyi Oct 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like your snippet because it is clean. We also considered a similar approach.

  1. Unfortunately, it does not handle nested If/CaseWhen expressions as they are not NullIntolerant. For example, cases like If(If(a > 1, FalseLiteral, Literal(null, _)), 1, 2) will not be optimized if we remove branches for If/CaseWhen.
  2. If we just add one more brach to handle all NullIntolerant expressions, I am not sure it will give a lot of benefits as those expressions are transformed into Literal(null, _) by NullPropagation and we operate in the same batch.
  3. As @gatorsmile said, we should be really careful. Generalization might be tricky. For example, Not is NullIntolerant. Not(null) is transformed into null by NullPropagation. We need to ensure that we do not replace null inside Not and do not convert Not(null) into Not(FalseLiteral).

Therefore, the intention was to keep things simple to be safe.

*/
private def replaceNullWithFalse(e: Expression): Expression = e match {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IsNull(Literal(null, _)) => IsNull(FalseLiteral)

Will this be a problem for this change?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only do the replacements when 1) within Join or Filter such as Filter(If(cond, FalseLiteral, Literal(null, _))), or 2) If(Literal(null, _), trueValue, falseValue).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, that's the reason why we don't use transformExpressionsDown. We will stop the replacement as soon as we hit an expression that is not CaseWhen, If, And, Or or Literal(null, _). Therefore, If(IsNull(Literal(null, _))) won't be transformed.

case cw: CaseWhen if cw.dataType == BooleanType =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When cw.dataType != BooleanType, we can still do replaceNullWithFalse(cond), don't we?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is also covered and tested in "replace null in conditions of CaseWhen", "replace null in conditions of CaseWhen inside another CaseWhen".

val newBranches = cw.branches.map { case (cond, value) =>
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
}
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
CaseWhen(newBranches, newElseValue)
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When i.dataType != BooleanType, we still can do replaceNullWithFalse(pred), don't we?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is handled in apply and tested in "replace null in predicates of If", "replace null in predicates of If inside another If"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if I got you correctly here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The general rule for LogicalPlan at apply looks at predicate of If, no matter its dataType is BooleanType or not.

But in replaceNullWithFalse, the rule for If only works if its dataType is BooleanType. "replace null in predicates of If inside another If" is a such case. The If inside another If is of BooleanType. If the inside If is not of BooleanType, this rule doesn't work. And I think it should be ok to replace the null with false when it is not boolean type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, I see. replaceNullWithFalse should only work in boolean type cases. Then I think we are fine with it.

case And(left, right) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to be careful here. null && fales is false, null || true is true. Please take a look at #22702

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate a bit more on null && false?

I had in mind AND(true, null) and OR(false, null), which are tricky. For example, if we use AND(true, null) in SELECT, we will get null. However, if we use it inside Filter or predicate of If, it will be semantically equivalent to false (e.g., If$eval). Therefore, the proposed rule has a limited scope. I explored the source code & comments in And/Or to come up with an edge case that wouldn’t work. I could not find such a case. To me, it seems safe because the rule is applied only to expressions that evaluate to false if the underlying expression is null (i.e., conditions in Filter/Join, predicates in If, conditions in CaseWhen).

Please, let me know if you have a particular case to test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a particular case, this is just to double check that these corner cases are considered. I think we are fine now :)

And(replaceNullWithFalse(left), replaceNullWithFalse(right))
case Or(left, right) =>
Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
case Literal(null, _) => FalseLiteral
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, for safety, we should check the data types.

case _ => e
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,323 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{BooleanType, IntegerType}

class ReplaceNullWithFalseSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Replace null literals", FixedPoint(10),
NullPropagation,
ConstantFolding,
BooleanSimplification,
SimplifyConditionals,
ReplaceNullWithFalse) :: Nil
}

private val testRelation = LocalRelation('i.int, 'b.boolean)
private val anotherTestRelation = LocalRelation('d.int)

test("replace null inside filter and join conditions") {
testFilter(originalCond = Literal(null), expectedCond = FalseLiteral)
testJoin(originalCond = Literal(null), expectedCond = FalseLiteral)
}

test("replace null in branches of If") {
val originalCond = If(
UnresolvedAttribute("i") > Literal(10),
FalseLiteral,
Literal(null, BooleanType))
testFilter(originalCond, expectedCond = FalseLiteral)
testJoin(originalCond, expectedCond = FalseLiteral)
}

test("replace nulls in nested expressions in branches of If") {
val originalCond = If(
UnresolvedAttribute("i") > Literal(10),
TrueLiteral && Literal(null, BooleanType),
UnresolvedAttribute("b") && Literal(null, BooleanType))
testFilter(originalCond, expectedCond = FalseLiteral)
testJoin(originalCond, expectedCond = FalseLiteral)
}

test("replace null in elseValue of CaseWhen") {
val branches = Seq(
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
(UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral)
val originalCond = CaseWhen(branches, Literal(null, BooleanType))
val expectedCond = CaseWhen(branches, FalseLiteral)
testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
}

test("replace null in branch values of CaseWhen") {
val branches = Seq(
(UnresolvedAttribute("i") < Literal(10)) -> Literal(null, BooleanType),
(UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral)
val originalCond = CaseWhen(branches, Literal(null))
testFilter(originalCond, expectedCond = FalseLiteral)
testJoin(originalCond, expectedCond = FalseLiteral)
}

test("replace null in branches of If inside CaseWhen") {
val originalBranches = Seq(
(UnresolvedAttribute("i") < Literal(10)) ->
If(UnresolvedAttribute("i") < Literal(20), Literal(null, BooleanType), FalseLiteral),
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
val originalCond = CaseWhen(originalBranches)

val expectedBranches = Seq(
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
val expectedCond = CaseWhen(expectedBranches)

testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
}

test("replace null in complex CaseWhen expressions") {
val originalBranches = Seq(
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
(Literal(6) <= Literal(1)) -> FalseLiteral,
(Literal(4) === Literal(5)) -> FalseLiteral,
(UnresolvedAttribute("i") > Literal(10)) -> Literal(null, BooleanType),
(Literal(4) === Literal(4)) -> TrueLiteral)
val originalCond = CaseWhen(originalBranches)

val expectedBranches = Seq(
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
(UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral,
TrueLiteral -> TrueLiteral)
val expectedCond = CaseWhen(expectedBranches)

testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
}

test("replace null in Or") {
val originalCond = Or(UnresolvedAttribute("b"), Literal(null))
val expectedCond = UnresolvedAttribute("b")
testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
}

test("replace null in And") {
val originalCond = And(UnresolvedAttribute("b"), Literal(null))
testFilter(originalCond, expectedCond = FalseLiteral)
testJoin(originalCond, expectedCond = FalseLiteral)
}

test("replace nulls in nested And/Or expressions") {
val originalCond = And(
And(UnresolvedAttribute("b"), Literal(null)),
Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"), Literal(null)))))
testFilter(originalCond, expectedCond = FalseLiteral)
testJoin(originalCond, expectedCond = FalseLiteral)
}

test("replace null in And inside branches of If") {
val originalCond = If(
UnresolvedAttribute("i") > Literal(10),
FalseLiteral,
And(UnresolvedAttribute("b"), Literal(null, BooleanType)))
testFilter(originalCond, expectedCond = FalseLiteral)
testJoin(originalCond, expectedCond = FalseLiteral)
}

test("replace null in branches of If inside And") {
val originalCond = And(
UnresolvedAttribute("b"),
If(
UnresolvedAttribute("i") > Literal(10),
Literal(null),
And(FalseLiteral, UnresolvedAttribute("b"))))
testFilter(originalCond, expectedCond = FalseLiteral)
testJoin(originalCond, expectedCond = FalseLiteral)
}

test("replace null in branches of If inside another If") {
val originalCond = If(
If(UnresolvedAttribute("b"), Literal(null), FalseLiteral),
TrueLiteral,
Literal(null))
testFilter(originalCond, expectedCond = FalseLiteral)
testJoin(originalCond, expectedCond = FalseLiteral)
}

test("replace null in CaseWhen inside another CaseWhen") {
val nestedCaseWhen = CaseWhen(Seq(UnresolvedAttribute("b") -> FalseLiteral), Literal(null))
val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral), Literal(null))
testFilter(originalCond, expectedCond = FalseLiteral)
testJoin(originalCond, expectedCond = FalseLiteral)
}

test("inability to replace null in non-boolean branches of If") {
val condition = If(
UnresolvedAttribute("i") > Literal(10),
Literal(5) > If(
UnresolvedAttribute("i") === Literal(15),
Literal(null, IntegerType),
Literal(3)),
FalseLiteral)
testFilter(originalCond = condition, expectedCond = condition)
testJoin(originalCond = condition, expectedCond = condition)
}

test("inability to replace null in non-boolean values of CaseWhen") {
val nestedCaseWhen = CaseWhen(
Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2)),
Literal(null, IntegerType))
val branchValue = If(
Literal(2) === nestedCaseWhen,
TrueLiteral,
FalseLiteral)
val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)
val condition = CaseWhen(branches)
testFilter(originalCond = condition, expectedCond = condition)
testJoin(originalCond = condition, expectedCond = condition)
}

test("inability to replace null in non-boolean branches of If inside another If") {
val condition = If(
Literal(5) > If(
UnresolvedAttribute("i") === Literal(15),
Literal(null, IntegerType),
Literal(3)),
TrueLiteral,
FalseLiteral)
testFilter(originalCond = condition, expectedCond = condition)
testJoin(originalCond = condition, expectedCond = condition)
}

test("replace null in If used as a join condition") {
// this test is only for joins as the condition involves columns from different relations
val originalCond = If(
UnresolvedAttribute("d") > UnresolvedAttribute("i"),
Literal(null),
FalseLiteral)
testJoin(originalCond, expectedCond = FalseLiteral)
}

test("replace null in CaseWhen used as a join condition") {
// this test is only for joins as the condition involves columns from different relations
val originalBranches = Seq(
(UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null),
(UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)

val expectedBranches = Seq(
(UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> FalseLiteral,
(UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)

testJoin(
originalCond = CaseWhen(originalBranches, FalseLiteral),
expectedCond = CaseWhen(expectedBranches, FalseLiteral))
}

test("inability to replace null in CaseWhen inside EqualTo used as a join condition") {
// this test is only for joins as the condition involves columns from different relations
val branches = Seq(
(UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null, BooleanType),
(UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)
val condition = UnresolvedAttribute("b") === CaseWhen(branches, FalseLiteral)
testJoin(originalCond = condition, expectedCond = condition)
}

test("replace null in predicates of If") {
val predicate = And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null))
testProjection(
originalExpr = If(predicate, Literal(5), Literal(1)).as("out"),
expectedExpr = Literal(1).as("out"))
}

test("replace null in predicates of If inside another If") {
val predicate = If(
And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)),
TrueLiteral,
FalseLiteral)
testProjection(
originalExpr = If(predicate, Literal(5), Literal(1)).as("out"),
expectedExpr = Literal(1).as("out"))
}

test("inability to replace null in non-boolean expressions inside If predicates") {
val predicate = GreaterThan(
UnresolvedAttribute("i"),
If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4)))
val column = If(predicate, Literal(5), Literal(1)).as("out")
testProjection(originalExpr = column, expectedExpr = column)
}

test("replace null in conditions of CaseWhen") {
val branches = Seq(
And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) -> Literal(5))
testProjection(
originalExpr = CaseWhen(branches, Literal(2)).as("out"),
expectedExpr = Literal(2).as("out"))
}

test("replace null in conditions of CaseWhen inside another CaseWhen") {
val nestedCaseWhen = CaseWhen(
Seq(And(UnresolvedAttribute("b"), Literal(null)) -> Literal(5)),
Literal(2))
val branches = Seq(GreaterThan(Literal(3), nestedCaseWhen) -> Literal(1))
testProjection(
originalExpr = CaseWhen(branches).as("out"),
expectedExpr = Literal(1).as("out"))
}

test("inability to replace null in non-boolean exprs inside CaseWhen conditions") {
val condition = GreaterThan(
UnresolvedAttribute("i"),
If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4)))
val column = CaseWhen(Seq(condition -> Literal(5)), Literal(2)).as("out")
testProjection(originalExpr = column, expectedExpr = column)
}

private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = {
test((rel, exp) => rel.where(exp), originalCond, expectedCond)
}

private def testJoin(originalCond: Expression, expectedCond: Expression): Unit = {
test((rel, exp) => rel.join(anotherTestRelation, Inner, Some(exp)), originalCond, expectedCond)
}

private def testProjection(originalExpr: Expression, expectedExpr: Expression): Unit = {
test((rel, exp) => rel.select(exp), originalExpr, expectedExpr)
}

private def test(
func: (LogicalPlan, Expression) => LogicalPlan,
originalExpr: Expression,
expectedExpr: Expression): Unit = {

val originalPlan = func(testRelation, originalExpr).analyze
val optimizedPlan = Optimize.execute(originalPlan)
val expectedPlan = func(testRelation, expectedExpr).analyze
comparePlans(optimizedPlan, expectedPlan)
}
}
Loading