From be7a23645a3a48f3a8afd9ea00ee118e764e8e8b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 13 Oct 2018 20:52:26 -0700 Subject: [PATCH] fix. --- .../sql/catalyst/expressions/predicates.scala | 35 ++++++ .../sql/catalyst/optimizer/expressions.scala | 40 +++++-- .../BooleanSimplificationSuite.scala | 111 ++++++++++++++---- 3 files changed, 153 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 02fb262ec845..ffc07d92896f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -120,6 +120,13 @@ case class Not(child: Expression) override def inputTypes: Seq[DataType] = Seq(BooleanType) + // +---------+-----------+ + // | CHILD | NOT CHILD | + // +---------+-----------+ + // | TRUE | FALSE | + // | FALSE | TRUE | + // | UNKNOWN | UNKNOWN | + // +---------+-----------+ protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean] override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -331,6 +338,13 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with override def sqlOperator: String = "AND" + // +---------+---------+---------+---------+ + // | AND | TRUE | FALSE | UNKNOWN | + // +---------+---------+---------+---------+ + // | TRUE | TRUE | FALSE | UNKNOWN | + // | FALSE | FALSE | FALSE | FALSE | + // | UNKNOWN | UNKNOWN | FALSE | UNKNOWN | + // +---------+---------+---------+---------+ override def eval(input: InternalRow): Any = { val input1 = left.eval(input) if (input1 == false) { @@ -433,6 +447,13 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P override def sqlOperator: String = "OR" + // +---------+---------+---------+---------+ + // | OR | TRUE | FALSE | UNKNOWN | + // +---------+---------+---------+---------+ + // | TRUE | TRUE | TRUE | TRUE | + // | FALSE | TRUE | FALSE | UNKNOWN | + // | UNKNOWN | TRUE | UNKNOWN | UNKNOWN | + // +---------+---------+---------+---------+ override def eval(input: InternalRow): Any = { val input1 = left.eval(input) if (input1 == true) { @@ -583,6 +604,13 @@ case class EqualTo(left: Expression, right: Expression) override def symbol: String = "=" + // +---------+---------+---------+---------+ + // | = | TRUE | FALSE | UNKNOWN | + // +---------+---------+---------+---------+ + // | TRUE | TRUE | FALSE | UNKNOWN | + // | FALSE | FALSE | TRUE | UNKNOWN | + // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN | + // +---------+---------+---------+---------+ protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -618,6 +646,13 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def nullable: Boolean = false + // +---------+---------+---------+---------+ + // | <=> | TRUE | FALSE | UNKNOWN | + // +---------+---------+---------+---------+ + // | TRUE | TRUE | FALSE | UNKNOWN | + // | FALSE | FALSE | TRUE | UNKNOWN | + // | UNKNOWN | UNKNOWN | UNKNOWN | TRUE | + // +---------+---------+---------+---------+ override def eval(input: InternalRow): Any = { val input1 = left.eval(input) val input2 = right.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index cf26ab0bc69e..3bf4de921b78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -167,15 +167,37 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case a And b if a.semanticEquals(b) => a case a Or b if a.semanticEquals(b) => a - case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c) - case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b) - case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c) - case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c) - - case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c) - case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b) - case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c) - case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c) + // The following optimizations are applicable only when the operands are not nullable, + // since the three-value logic of AND and OR are different in NULL handling. + // See the chart: + // +---------+---------+---------+---------+ + // | operand | operand | OR | AND | + // +---------+---------+---------+---------+ + // | TRUE | TRUE | TRUE | TRUE | + // | TRUE | FALSE | TRUE | FALSE | + // | FALSE | FALSE | FALSE | FALSE | + // | UNKNOWN | TRUE | TRUE | UNKNOWN | + // | UNKNOWN | FALSE | UNKNOWN | FALSE | + // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN | + // +---------+---------+---------+---------+ + + // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable. + case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c) + // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable. + case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b) + // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable. + case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c) + // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable. + case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c) + + // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable. + case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c) + // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable. + case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b) + // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable. + case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c) + // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable. + case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c) // Common factor elimination for conjunction case and @ (left And right) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index df2624eaba18..34e24648122e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.BooleanType -class BooleanSimplificationSuite extends PlanTest with PredicateHelper { +class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -71,6 +71,14 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { comparePlans(actual, correctAnswer) } + private def checkConditionInNotNullableRelation( + input: Expression, expected: Expression): Unit = { + val plan = testNotNullableRelationWithData.where(input).analyze + val actual = Optimize.execute(plan) + val correctAnswer = testNotNullableRelationWithData.where(expected).analyze + comparePlans(actual, correctAnswer) + } + private def checkConditionInNotNullableRelation( input: Expression, expected: LogicalPlan): Unit = { val plan = testNotNullableRelationWithData.where(input).analyze @@ -119,42 +127,55 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { 'a === 'b || 'b > 3 && 'a > 3 && 'a < 5) } - test("a && (!a || b)") { - checkCondition('a && (!'a || 'b ), 'a && 'b) + test("e && (!e || f) - not nullable") { + checkConditionInNotNullableRelation('e && (!'e || 'f ), 'e && 'f) - checkCondition('a && ('b || !'a ), 'a && 'b) + checkConditionInNotNullableRelation('e && ('f || !'e ), 'e && 'f) - checkCondition((!'a || 'b ) && 'a, 'b && 'a) + checkConditionInNotNullableRelation((!'e || 'f ) && 'e, 'f && 'e) - checkCondition(('b || !'a ) && 'a, 'b && 'a) + checkConditionInNotNullableRelation(('f || !'e ) && 'e, 'f && 'e) } - test("a < 1 && (!(a < 1) || b)") { - checkCondition('a < 1 && (!('a < 1) || 'b), ('a < 1) && 'b) - checkCondition('a < 1 && ('b || !('a < 1)), ('a < 1) && 'b) + test("e && (!e || f) - nullable") { + Seq ('e && (!'e || 'f ), + 'e && ('f || !'e ), + (!'e || 'f ) && 'e, + ('f || !'e ) && 'e, + 'e || (!'e && 'f), + 'e || ('f && !'e), + ('e && 'f) || !'e, + ('f && 'e) || !'e).foreach { expr => + checkCondition(expr, expr) + } + } - checkCondition('a <= 1 && (!('a <= 1) || 'b), ('a <= 1) && 'b) - checkCondition('a <= 1 && ('b || !('a <= 1)), ('a <= 1) && 'b) + test("a < 1 && (!(a < 1) || f) - not nullable") { + checkConditionInNotNullableRelation('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f) + checkConditionInNotNullableRelation('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f) - checkCondition('a > 1 && (!('a > 1) || 'b), ('a > 1) && 'b) - checkCondition('a > 1 && ('b || !('a > 1)), ('a > 1) && 'b) + checkConditionInNotNullableRelation('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f) + checkConditionInNotNullableRelation('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f) - checkCondition('a >= 1 && (!('a >= 1) || 'b), ('a >= 1) && 'b) - checkCondition('a >= 1 && ('b || !('a >= 1)), ('a >= 1) && 'b) + checkConditionInNotNullableRelation('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f) + checkConditionInNotNullableRelation('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f) + + checkConditionInNotNullableRelation('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f) + checkConditionInNotNullableRelation('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f) } - test("a < 1 && ((a >= 1) || b)") { - checkCondition('a < 1 && ('a >= 1 || 'b ), ('a < 1) && 'b) - checkCondition('a < 1 && ('b || 'a >= 1), ('a < 1) && 'b) + test("a < 1 && ((a >= 1) || f) - not nullable") { + checkConditionInNotNullableRelation('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f) + checkConditionInNotNullableRelation('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f) - checkCondition('a <= 1 && ('a > 1 || 'b ), ('a <= 1) && 'b) - checkCondition('a <= 1 && ('b || 'a > 1), ('a <= 1) && 'b) + checkConditionInNotNullableRelation('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f) + checkConditionInNotNullableRelation('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f) - checkCondition('a > 1 && (('a <= 1) || 'b), ('a > 1) && 'b) - checkCondition('a > 1 && ('b || ('a <= 1)), ('a > 1) && 'b) + checkConditionInNotNullableRelation('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f) + checkConditionInNotNullableRelation('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f) - checkCondition('a >= 1 && (('a < 1) || 'b), ('a >= 1) && 'b) - checkCondition('a >= 1 && ('b || ('a < 1)), ('a >= 1) && 'b) + checkConditionInNotNullableRelation('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f) + checkConditionInNotNullableRelation('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f) } test("DeMorgan's law") { @@ -217,4 +238,46 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition('e || !'f, testRelationWithData.where('e || !'f).analyze) checkCondition(!'f || 'e, testRelationWithData.where(!'f || 'e).analyze) } + + protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze) + comparePlans(actual, correctAnswer) + } + + test("filter reduction - positive cases") { + val fields = Seq( + 'col1NotNULL.boolean.notNull, + 'col2NotNULL.boolean.notNull + ) + val Seq(col1NotNULL, col2NotNULL) = fields.zipWithIndex.map { case (f, i) => f.at(i) } + + val exprs = Seq( + // actual expressions of the transformations: original -> transformed + (col1NotNULL && (!col1NotNULL || col2NotNULL)) -> (col1NotNULL && col2NotNULL), + (col1NotNULL && (col2NotNULL || !col1NotNULL)) -> (col1NotNULL && col2NotNULL), + ((!col1NotNULL || col2NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL), + ((col2NotNULL || !col1NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL), + + (col1NotNULL || (!col1NotNULL && col2NotNULL)) -> (col1NotNULL || col2NotNULL), + (col1NotNULL || (col2NotNULL && !col1NotNULL)) -> (col1NotNULL || col2NotNULL), + ((!col1NotNULL && col2NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL), + ((col2NotNULL && !col1NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL) + ) + + // check plans + for ((originalExpr, expectedExpr) <- exprs) { + assertEquivalent(originalExpr, expectedExpr) + } + + // check evaluation + val binaryBooleanValues = Seq(true, false) + for (col1NotNULLVal <- binaryBooleanValues; + col2NotNULLVal <- binaryBooleanValues; + (originalExpr, expectedExpr) <- exprs) { + val inputRow = create_row(col1NotNULLVal, col2NotNULLVal) + val optimizedVal = evaluate(expectedExpr, inputRow) + checkEvaluation(originalExpr, optimizedVal, inputRow) + } + } }