diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 698ece4f9e69f..4a71dba663b38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If} import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or} import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral -import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.BooleanType import org.apache.spark.util.Utils @@ -54,6 +54,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) + case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond))) case p: LogicalPlan => p transformExpressions { case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) case cw @ CaseWhen(branches, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 6fc31c94e47eb..00433a5490574 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, IntegerType} @@ -49,6 +49,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) testDelete(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + testUpdate(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) } test("Not expected type - replaceNullWithFalse") { @@ -66,6 +67,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) + testUpdate(originalCond, expectedCond = FalseLiteral) } test("replace nulls in nested expressions in branches of If") { @@ -76,6 +78,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) + testUpdate(originalCond, expectedCond = FalseLiteral) } test("replace null in elseValue of CaseWhen") { @@ -87,6 +90,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) testDelete(originalCond, expectedCond) + testUpdate(originalCond, expectedCond) } test("replace null in branch values of CaseWhen") { @@ -97,6 +101,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) + testUpdate(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside CaseWhen") { @@ -114,6 +119,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) testDelete(originalCond, expectedCond) + testUpdate(originalCond, expectedCond) } test("replace null in complex CaseWhen expressions") { @@ -134,6 +140,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) testDelete(originalCond, expectedCond) + testUpdate(originalCond, expectedCond) } test("replace null in Or") { @@ -142,6 +149,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) testDelete(originalCond, expectedCond) + testUpdate(originalCond, expectedCond) } test("replace null in And") { @@ -149,6 +157,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) + testUpdate(originalCond, expectedCond = FalseLiteral) } test("replace nulls in nested And/Or expressions") { @@ -158,6 +167,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) + testUpdate(originalCond, expectedCond = FalseLiteral) } test("replace null in And inside branches of If") { @@ -168,6 +178,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) + testUpdate(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside And") { @@ -180,6 +191,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) + testUpdate(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside another If") { @@ -190,6 +202,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) + testUpdate(originalCond, expectedCond = FalseLiteral) } test("replace null in CaseWhen inside another CaseWhen") { @@ -198,6 +211,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) + testUpdate(originalCond, expectedCond = FalseLiteral) } test("inability to replace null in non-boolean branches of If") { @@ -211,6 +225,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond = condition, expectedCond = condition) testJoin(originalCond = condition, expectedCond = condition) testDelete(originalCond = condition, expectedCond = condition) + testUpdate(originalCond = condition, expectedCond = condition) } test("inability to replace null in non-boolean values of CaseWhen") { @@ -226,6 +241,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond = condition, expectedCond = condition) testJoin(originalCond = condition, expectedCond = condition) testDelete(originalCond = condition, expectedCond = condition) + testUpdate(originalCond = condition, expectedCond = condition) } test("inability to replace null in non-boolean branches of If inside another If") { @@ -239,6 +255,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond = condition, expectedCond = condition) testJoin(originalCond = condition, expectedCond = condition) testDelete(originalCond = condition, expectedCond = condition) + testUpdate(originalCond = condition, expectedCond = condition) } test("replace null in If used as a join condition") { @@ -374,6 +391,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { test((rel, expr) => DeleteFromTable(rel, Some(expr)), originalCond, expectedCond) } + private def testUpdate(originalCond: Expression, expectedCond: Expression): Unit = { + test((rel, expr) => UpdateTable(rel, Seq.empty, Some(expr)), originalCond, expectedCond) + } + private def testHigherOrderFunc( argument: Expression, createExpr: (Expression, Expression) => Expression,