Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ case class Not(child: Expression)

override def inputTypes: Seq[DataType] = Seq(BooleanType)

// +---------+-----------+
// | CHILD | NOT CHILD |
// +---------+-----------+
// | TRUE | FALSE |
// | FALSE | TRUE |
// | UNKNOWN | UNKNOWN |
// +---------+-----------+
protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean]

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -404,6 +411,13 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with

override def sqlOperator: String = "AND"

// +---------+---------+---------+---------+
// | AND | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | FALSE | UNKNOWN |
// | FALSE | FALSE | FALSE | FALSE |
// | UNKNOWN | UNKNOWN | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == false) {
Expand Down Expand Up @@ -467,6 +481,13 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P

override def sqlOperator: String = "OR"

// +---------+---------+---------+---------+
// | OR | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | TRUE | TRUE |
// | FALSE | TRUE | FALSE | UNKNOWN |
// | UNKNOWN | TRUE | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == true) {
Expand Down Expand Up @@ -590,6 +611,13 @@ case class EqualTo(left: Expression, right: Expression)

override def symbol: String = "="

// +---------+---------+---------+---------+
// | = | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | FALSE | UNKNOWN |
// | FALSE | FALSE | TRUE | UNKNOWN |
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+
protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -627,6 +655,13 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp

override def nullable: Boolean = false

// +---------+---------+---------+---------+
// | <=> | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | FALSE | UNKNOWN |
// | FALSE | FALSE | TRUE | UNKNOWN |
// | UNKNOWN | UNKNOWN | UNKNOWN | TRUE |
// +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
val input2 = right.eval(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,15 +276,31 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
case a And b if a.semanticEquals(b) => a
case a Or b if a.semanticEquals(b) => a

case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c)
case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b)
case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c)
case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c)

case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c)
case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b)
case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c)
case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c)
// The following optimization is applicable only when the operands are nullable,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: only when the operands are not nullable

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

// since the three-value logic of AND and OR are different in NULL handling.
// See the chart:
// +---------+---------+---------+---------+
// | p | q | p OR q | p AND q |
// +---------+---------+---------+---------+
// | TRUE | TRUE | TRUE | TRUE |
// | TRUE | FALSE | TRUE | FALSE |
// | TRUE | UNKNOWN | TRUE | UNKNOWN |
// | FALSE | TRUE | TRUE | FALSE |
// | FALSE | FALSE | FALSE | FALSE |
// | FALSE | UNKNOWN | UNKNOWN | FALSE |
// | UNKNOWN | TRUE | TRUE | UNKNOWN |
// | UNKNOWN | FALSE | UNKNOWN | FALSE |
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
Copy link
Contributor

@cloud-fan cloud-fan Oct 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assuming a is null, then b is also null.
If c is null: a And (b Or c) -> null, And(a, c) -> null
If c is true: a And (b Or c) -> null, And(a, c) -> null
if c is false: a And (b Or c) -> null, And(a, c) -> false

So yes this is a bug, and we should rewrite it to If(IsNull(a), null, And(a, c)), because if a is null, the result is always null.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is complicated, shall we put a comment to explain it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after more thoughts, a And (b Or c) should be better than If(IsNull(a), null, And(a, c)), as it's more likely to get pushed down to data source, so the changes here LGTM

case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
case (a Or b) And c if !a.nullable && a.semanticEquals(Not(c)) => And(b, c)
case (a Or b) And c if !b.nullable && b.semanticEquals(Not(c)) => And(a, c)

case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these shouldn't be a problem, since if a is true, then a Or b is true, regardless of b's value/nullability, isn't it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is when a is null, c is true

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see now, sorry. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, it is the other case where the change is not needed, right?
a And (b Or c) -> And(a, c) when a is null, And(a, c) returns null (I got a bit confused earlier, sorry).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when a is null, And(a, c) returns null

This is not always the case, null && false is false

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, yes you're right, this might be a problem indeed if the expression is inside a not. Sorry, thanks.

case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
case (a And b) Or c if !a.nullable && a.semanticEquals(Not(c)) => Or(b, c)
case (a And b) Or c if !b.nullable && b.semanticEquals(Not(c)) => Or(a, c)

// Common factor elimination for conjunction
case and @ (left And right) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BooleanType

class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Expand Down Expand Up @@ -71,6 +71,14 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
comparePlans(actual, correctAnswer)
}

private def checkConditionInNotNullableRelation(
input: Expression, expected: Expression): Unit = {
val plan = testNotNullableRelationWithData.where(input).analyze
val actual = Optimize.execute(plan)
val correctAnswer = testNotNullableRelationWithData.where(expected).analyze
comparePlans(actual, correctAnswer)
}

private def checkConditionInNotNullableRelation(
input: Expression, expected: LogicalPlan): Unit = {
val plan = testNotNullableRelationWithData.where(input).analyze
Expand Down Expand Up @@ -119,42 +127,55 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
'a === 'b || 'b > 3 && 'a > 3 && 'a < 5)
}

test("e && (!e || f)") {
checkCondition('e && (!'e || 'f ), 'e && 'f)
test("e && (!e || f) - not nullable") {
checkConditionInNotNullableRelation('e && (!'e || 'f ), 'e && 'f)

checkCondition('e && ('f || !'e ), 'e && 'f)
checkConditionInNotNullableRelation('e && ('f || !'e ), 'e && 'f)

checkCondition((!'e || 'f ) && 'e, 'f && 'e)
checkConditionInNotNullableRelation((!'e || 'f ) && 'e, 'f && 'e)

checkCondition(('f || !'e ) && 'e, 'f && 'e)
checkConditionInNotNullableRelation(('f || !'e ) && 'e, 'f && 'e)
}

test("a < 1 && (!(a < 1) || f)") {
checkCondition('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f)
checkCondition('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f)
test("e && (!e || f) - nullable") {
Seq ('e && (!'e || 'f ),
'e && ('f || !'e ),
(!'e || 'f ) && 'e,
('f || !'e ) && 'e,
'e || (!'e && 'f),
'e || ('f && !'e),
('e && 'f) || !'e,
('f && 'e) || !'e).foreach { expr =>
checkCondition(expr, expr)
}
}

checkCondition('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f)
checkCondition('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f)
test("a < 1 && (!(a < 1) || f) - not nullable") {
checkConditionInNotNullableRelation('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f)
checkConditionInNotNullableRelation('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f)

checkCondition('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f)
checkCondition('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f)

checkCondition('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f)
checkCondition('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f)

checkConditionInNotNullableRelation('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f)
}

test("a < 1 && ((a >= 1) || f)") {
checkCondition('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f)
checkCondition('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f)
test("a < 1 && ((a >= 1) || f) - not nullable") {
checkConditionInNotNullableRelation('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f)
checkConditionInNotNullableRelation('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f)

checkCondition('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f)
checkCondition('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f)

checkCondition('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f)
checkCondition('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f)

checkCondition('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f)
checkCondition('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f)
}

test("DeMorgan's law") {
Expand Down Expand Up @@ -217,4 +238,46 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
checkCondition('e || !'f, testRelationWithData.where('e || !'f).analyze)
checkCondition(!'f || 'e, testRelationWithData.where(!'f || 'e).analyze)
}

protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
comparePlans(actual, correctAnswer)
}

test("filter reduction - positive cases") {
val fields = Seq(
'col1NotNULL.boolean.notNull,
'col2NotNULL.boolean.notNull
)
val Seq(col1NotNULL, col2NotNULL) = fields.zipWithIndex.map { case (f, i) => f.at(i) }

val exprs = Seq(
// actual expressions of the transformations: original -> transformed
(col1NotNULL && (!col1NotNULL || col2NotNULL)) -> (col1NotNULL && col2NotNULL),
(col1NotNULL && (col2NotNULL || !col1NotNULL)) -> (col1NotNULL && col2NotNULL),
((!col1NotNULL || col2NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL),
((col2NotNULL || !col1NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL),

(col1NotNULL || (!col1NotNULL && col2NotNULL)) -> (col1NotNULL || col2NotNULL),
(col1NotNULL || (col2NotNULL && !col1NotNULL)) -> (col1NotNULL || col2NotNULL),
((!col1NotNULL && col2NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL),
((col2NotNULL && !col1NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL)
)

// check plans
for ((originalExpr, expectedExpr) <- exprs) {
assertEquivalent(originalExpr, expectedExpr)
}

// check evaluation
val binaryBooleanValues = Seq(true, false)
for (col1NotNULLVal <- binaryBooleanValues;
col2NotNULLVal <- binaryBooleanValues;
(originalExpr, expectedExpr) <- exprs) {
val inputRow = create_row(col1NotNULLVal, col2NotNULLVal)
val optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow)
checkEvaluation(originalExpr, optimizedVal, inputRow)
}
}
}
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
Expand Down Expand Up @@ -2567,4 +2568,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {

checkAnswer(df.where("(NOT a) OR a"), Seq.empty)
}

test("SPARK-25714 Null handling in BooleanSimplification") {
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ConvertToLocalRelation.ruleName) {
val df = Seq(("abc", 1), (null, 3)).toDF("col1", "col2")
checkAnswer(
df.filter("col1 = 'abc' OR (col1 != 'abc' AND col2 == 3)"),
Row ("abc", 1))
}
}
}