Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
case If(Literal(null, _), _, falseValue) => falseValue
case If(cond, TrueLiteral, FalseLiteral) => cond
case If(cond, FalseLiteral, TrueLiteral) => Not(cond)
case If(cond, trueValue, falseValue)
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PushFoldableIntoBranchesSuite

test("Push down EqualTo through If") {
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral))
assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a))

// Push down at most one not foldable expressions.
assertEquivalent(
Expand All @@ -67,7 +67,7 @@ class PushFoldableIntoBranchesSuite
val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(2))
assert(!nonDeterministic.deterministic)
assertEquivalent(EqualTo(nonDeterministic, Literal(2)),
If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, TrueLiteral))
GreaterThanOrEqual(Rand(1), Literal(0.5)))
assertEquivalent(EqualTo(nonDeterministic, Literal(3)),
If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, FalseLiteral))

Expand Down Expand Up @@ -102,8 +102,7 @@ class PushFoldableIntoBranchesSuite
assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3)))
assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)),
If(a, Literal(2.0), Literal(3.0)))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral),
If(a, FalseLiteral, TrueLiteral))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a))
assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
Expand Down Expand Up @@ -236,12 +236,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
Literal(2) === nestedCaseWhen,
TrueLiteral,
FalseLiteral)
val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)
val condition = CaseWhen(branches)
testFilter(originalCond = condition, expectedCond = condition)
testJoin(originalCond = condition, expectedCond = condition)
testDelete(originalCond = condition, expectedCond = condition)
testUpdate(originalCond = condition, expectedCond = condition)
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
val expectedCond =
CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen)))
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
testUpdate(originalCond = condition, expectedCond = expectedCond)
}

test("inability to replace null in non-boolean branches of If inside another If") {
Expand All @@ -252,10 +253,14 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
Literal(3)),
TrueLiteral,
FalseLiteral)
testFilter(originalCond = condition, expectedCond = condition)
testJoin(originalCond = condition, expectedCond = condition)
testDelete(originalCond = condition, expectedCond = condition)
testUpdate(originalCond = condition, expectedCond = condition)
val expectedCond = Literal(5) > If(
UnresolvedAttribute("i") === Literal(15),
Literal(null, IntegerType),
Literal(3))
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
testUpdate(originalCond = condition, expectedCond = expectedCond)
}

test("replace null in If used as a join condition") {
Expand Down Expand Up @@ -405,9 +410,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
val lambda1 = LambdaFunction(
function = If(cond, Literal(null, BooleanType), TrueLiteral),
arguments = lambdaArgs)
// the optimized lambda body is: if(arg > 0, false, true)
// the optimized lambda body is: if(arg > 0, false, true) => arg <= 0
val lambda2 = LambdaFunction(
function = If(cond, FalseLiteral, TrueLiteral),
function = LessThanOrEqual(condArg, Literal(0)),
arguments = lambdaArgs)
testProjection(
originalExpr = createExpr(argument, lambda1) as 'x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,20 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow))
}
}

test("SPARK-33845: remove unnecessary if when the outputs are boolean type") {
assertEquivalent(
If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral),
IsNotNull(UnresolvedAttribute("a")))
assertEquivalent(
If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
IsNull(UnresolvedAttribute("a")))

assertEquivalent(
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral),
GreaterThan(Rand(0), UnresolvedAttribute("a")))
assertEquivalent(
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
LessThanOrEqual(Rand(0), UnresolvedAttribute("a")))
}
}