Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ case class Not(child: Expression)

override def inputTypes: Seq[DataType] = Seq(BooleanType)

// +---------+-----------+
// | CHILD | NOT CHILD |
// +---------+-----------+
// | TRUE | FALSE |
// | FALSE | TRUE |
// | UNKNOWN | UNKNOWN |
// +---------+-----------+
protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean]

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -374,6 +381,13 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with

override def sqlOperator: String = "AND"

// +---------+---------+---------+---------+
// | AND | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | FALSE | UNKNOWN |
// | FALSE | FALSE | FALSE | FALSE |
// | UNKNOWN | UNKNOWN | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == false) {
Expand Down Expand Up @@ -437,6 +451,13 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P

override def sqlOperator: String = "OR"

// +---------+---------+---------+---------+
// | OR | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | TRUE | TRUE |
// | FALSE | TRUE | FALSE | UNKNOWN |
// | UNKNOWN | TRUE | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == true) {
Expand Down Expand Up @@ -560,6 +581,13 @@ case class EqualTo(left: Expression, right: Expression)

override def symbol: String = "="

// +---------+---------+---------+---------+
// | = | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | FALSE | UNKNOWN |
// | FALSE | FALSE | TRUE | UNKNOWN |
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+
protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -597,6 +625,13 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp

override def nullable: Boolean = false

// +---------+---------+---------+---------+
// | <=> | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | FALSE | UNKNOWN |
// | FALSE | FALSE | TRUE | UNKNOWN |
// | UNKNOWN | UNKNOWN | UNKNOWN | TRUE |
// +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
val input2 = right.eval(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,37 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
case a And b if a.semanticEquals(b) => a
case a Or b if a.semanticEquals(b) => a

case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c)
case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b)
case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c)
case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c)

case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c)
case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b)
case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c)
case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c)
// The following optimizations are applicable only when the operands are not nullable,
// since the three-value logic of AND and OR are different in NULL handling.
// See the chart:
// +---------+---------+---------+---------+
// | operand | operand | OR | AND |
// +---------+---------+---------+---------+
// | TRUE | TRUE | TRUE | TRUE |
// | TRUE | FALSE | TRUE | FALSE |
// | FALSE | FALSE | FALSE | FALSE |
// | UNKNOWN | TRUE | TRUE | UNKNOWN |
// | UNKNOWN | FALSE | UNKNOWN | FALSE |
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+

// (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
// (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
// ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
// ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)

// (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
// (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
// ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
// ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)

// Common factor elimination for conjunction
case and @ (left And right) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BooleanType

class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Expand Down Expand Up @@ -71,6 +71,14 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
comparePlans(actual, correctAnswer)
}

private def checkConditionInNotNullableRelation(
input: Expression, expected: Expression): Unit = {
val plan = testNotNullableRelationWithData.where(input).analyze
val actual = Optimize.execute(plan)
val correctAnswer = testNotNullableRelationWithData.where(expected).analyze
comparePlans(actual, correctAnswer)
}

private def checkConditionInNotNullableRelation(
input: Expression, expected: LogicalPlan): Unit = {
val plan = testNotNullableRelationWithData.where(input).analyze
Expand Down Expand Up @@ -119,42 +127,55 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
'a === 'b || 'b > 3 && 'a > 3 && 'a < 5)
}

test("e && (!e || f)") {
checkCondition('e && (!'e || 'f ), 'e && 'f)
test("e && (!e || f) - not nullable") {
checkConditionInNotNullableRelation('e && (!'e || 'f ), 'e && 'f)

checkCondition('e && ('f || !'e ), 'e && 'f)
checkConditionInNotNullableRelation('e && ('f || !'e ), 'e && 'f)

checkCondition((!'e || 'f ) && 'e, 'f && 'e)
checkConditionInNotNullableRelation((!'e || 'f ) && 'e, 'f && 'e)

checkCondition(('f || !'e ) && 'e, 'f && 'e)
checkConditionInNotNullableRelation(('f || !'e ) && 'e, 'f && 'e)
}

test("a < 1 && (!(a < 1) || f)") {
checkCondition('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f)
checkCondition('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f)
test("e && (!e || f) - nullable") {
Seq ('e && (!'e || 'f ),
'e && ('f || !'e ),
(!'e || 'f ) && 'e,
('f || !'e ) && 'e,
'e || (!'e && 'f),
'e || ('f && !'e),
('e && 'f) || !'e,
('f && 'e) || !'e).foreach { expr =>
checkCondition(expr, expr)
}
}

checkCondition('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f)
checkCondition('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f)
test("a < 1 && (!(a < 1) || f) - not nullable") {
checkConditionInNotNullableRelation('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f)
checkConditionInNotNullableRelation('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f)

checkCondition('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f)
checkCondition('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f)

checkCondition('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f)
checkCondition('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f)

checkConditionInNotNullableRelation('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f)
}

test("a < 1 && ((a >= 1) || f)") {
checkCondition('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f)
checkCondition('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f)
test("a < 1 && ((a >= 1) || f) - not nullable") {
checkConditionInNotNullableRelation('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f)
checkConditionInNotNullableRelation('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f)

checkCondition('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f)
checkCondition('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f)

checkCondition('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f)
checkCondition('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f)

checkCondition('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f)
checkCondition('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f)
}

test("DeMorgan's law") {
Expand Down Expand Up @@ -217,4 +238,46 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
checkCondition('e || !'f, testRelationWithData.where('e || !'f).analyze)
checkCondition(!'f || 'e, testRelationWithData.where(!'f || 'e).analyze)
}

protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
comparePlans(actual, correctAnswer)
}

test("filter reduction - positive cases") {
val fields = Seq(
'col1NotNULL.boolean.notNull,
'col2NotNULL.boolean.notNull
)
val Seq(col1NotNULL, col2NotNULL) = fields.zipWithIndex.map { case (f, i) => f.at(i) }

val exprs = Seq(
// actual expressions of the transformations: original -> transformed
(col1NotNULL && (!col1NotNULL || col2NotNULL)) -> (col1NotNULL && col2NotNULL),
(col1NotNULL && (col2NotNULL || !col1NotNULL)) -> (col1NotNULL && col2NotNULL),
((!col1NotNULL || col2NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL),
((col2NotNULL || !col1NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL),

(col1NotNULL || (!col1NotNULL && col2NotNULL)) -> (col1NotNULL || col2NotNULL),
(col1NotNULL || (col2NotNULL && !col1NotNULL)) -> (col1NotNULL || col2NotNULL),
((!col1NotNULL && col2NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL),
((col2NotNULL && !col1NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL)
)

// check plans
for ((originalExpr, expectedExpr) <- exprs) {
assertEquivalent(originalExpr, expectedExpr)
}

// check evaluation
val binaryBooleanValues = Seq(true, false)
for (col1NotNULLVal <- binaryBooleanValues;
col2NotNULLVal <- binaryBooleanValues;
(originalExpr, expectedExpr) <- exprs) {
val inputRow = create_row(col1NotNULLVal, col2NotNULLVal)
val optimizedVal = evaluate(expectedExpr, inputRow)
checkEvaluation(originalExpr, optimizedVal, inputRow)
}
}
}