diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 5629b7289422..f8037588fa71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -263,10 +263,15 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case TrueLiteral Or _ => TrueLiteral case _ Or TrueLiteral => TrueLiteral - case a And b if Not(a).semanticEquals(b) => FalseLiteral - case a Or b if Not(a).semanticEquals(b) => TrueLiteral - case a And b if a.semanticEquals(Not(b)) => FalseLiteral - case a Or b if a.semanticEquals(Not(b)) => TrueLiteral + case a And b if Not(a).semanticEquals(b) => + If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral) + case a And b if a.semanticEquals(Not(b)) => + If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral) + + case a Or b if Not(a).semanticEquals(b) => + If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral) + case a Or b if a.semanticEquals(Not(b)) => + If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral) case a And b if a.semanticEquals(b) => a case a Or b if a.semanticEquals(b) => a diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 653c07f1835c..6cd1108eef33 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.BooleanType class BooleanSimplificationSuite extends PlanTest with PredicateHelper { @@ -37,6 +38,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { Batch("Constant Folding", FixedPoint(50), NullPropagation, ConstantFolding, + SimplifyConditionals, BooleanSimplification, PruneFilters) :: Nil } @@ -48,6 +50,14 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { testRelation.output, Seq(Row(1, 2, 3, "abc")) ) + val testNotNullableRelation = LocalRelation('a.int.notNull, 'b.int.notNull, 'c.int.notNull, + 'd.string.notNull, 'e.boolean.notNull, 'f.boolean.notNull, 'g.boolean.notNull, + 'h.boolean.notNull) + + val testNotNullableRelationWithData = LocalRelation.fromExternalRows( + testNotNullableRelation.output, Seq(Row(1, 2, 3, "abc")) + ) + private def checkCondition(input: Expression, expected: LogicalPlan): Unit = { val plan = testRelationWithData.where(input).analyze val actual = Optimize.execute(plan) @@ -61,6 +71,13 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { comparePlans(actual, correctAnswer) } + private def checkConditionInNotNullableRelation( + input: Expression, expected: LogicalPlan): Unit = { + val plan = testNotNullableRelationWithData.where(input).analyze + val actual = Optimize.execute(plan) + comparePlans(actual, expected) + } + test("a && a => a") { checkCondition(Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a) checkCondition(Literal(1) < 'a && Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a) @@ -174,10 +191,30 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { } test("Complementation Laws") { - checkCondition('a && !'a, testRelation) - checkCondition(!'a && 'a, testRelation) + checkConditionInNotNullableRelation('e && !'e, testNotNullableRelation) + checkConditionInNotNullableRelation(!'e && 'e, testNotNullableRelation) + + checkConditionInNotNullableRelation('e || !'e, testNotNullableRelationWithData) + checkConditionInNotNullableRelation(!'e || 'e, testNotNullableRelationWithData) + } + + test("Complementation Laws - null handling") { + checkCondition('e && !'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), false)).analyze) + checkCondition(!'e && 'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), false)).analyze) + + checkCondition('e || !'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), true)).analyze) + checkCondition(!'e || 'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), true)).analyze) + } + + test("Complementation Laws - negative case") { + checkCondition('e && !'f, testRelationWithData.where('e && !'f).analyze) + checkCondition(!'f && 'e, testRelationWithData.where(!'f && 'e).analyze) - checkCondition('a || !'a, testRelationWithData) - checkCondition(!'a || 'a, testRelationWithData) + checkCondition('e || !'f, testRelationWithData.where('e || !'f).analyze) + checkCondition(!'f || 'e, testRelationWithData.where(!'f || 'e).analyze) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 435b887cb3c7..279b7b8d49f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2569,4 +2569,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) check(lit(2).cast("int"), $"c" =!= 2, Seq()) } + + test("SPARK-25402 Null handling in BooleanSimplification") { + val schema = StructType.fromDDL("a boolean, b int") + val rows = Seq(Row(null, 1)) + + val rdd = sparkContext.parallelize(rows) + val df = spark.createDataFrame(rdd, schema) + + checkAnswer(df.where("(NOT a) OR a"), Seq.empty) + } }