Skip to content

Commit 4ab3a58

Browse files
committed
Fixes test failure, adds more tests
1 parent 5d54349 commit 4ab3a58

File tree

4 files changed

+32
-13
lines changed

4 files changed

+32
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import scala.collection.immutable.HashSet
2120
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
2221
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2322
import org.apache.spark.sql.catalyst.types.BooleanType
@@ -48,6 +47,14 @@ trait PredicateHelper {
4847
}
4948
}
5049

50+
protected def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = {
51+
condition match {
52+
case Or(cond1, cond2) =>
53+
splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2)
54+
case other => other :: Nil
55+
}
56+
}
57+
5158
/**
5259
* Returns true if `expr` can be evaluated using only the output of `plan`. This method
5360
* can be used to determine when is is acceptable to move expression evaluation within a query

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ object CombineFilters extends Rule[LogicalPlan] {
349349
}
350350

351351
/**
352-
* Normalizes conjuctions and disjunctions to eliminate common factors.
352+
* Normalizes conjunctions and disjunctions to eliminate common factors.
353353
*/
354354
object NormalizeFilters extends Rule[LogicalPlan] with PredicateHelper {
355355
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -358,17 +358,23 @@ object NormalizeFilters extends Rule[LogicalPlan] with PredicateHelper {
358358
}
359359

360360
def normalizedPredicate(predicate: Expression): Seq[Expression] = predicate match {
361-
// a || a => a
362-
case Or(lhs, rhs) if lhs fastEquals rhs => lhs :: Nil
363-
// a && a => a
364-
case And(lhs, rhs) if lhs fastEquals rhs => lhs :: Nil
361+
// a && a && a ... => a
362+
case p @ And(e, _) if splitConjunctivePredicates(p).distinct.size == 1 => e :: Nil
363+
364+
// a || a || a ... => a
365+
case p @ Or(e, _) if splitDisjunctivePredicates(p).distinct.size == 1 => e :: Nil
366+
365367
// (a && b && c && ...) || (a && b && d && ...) => a && b && (c || d || ...)
366368
case Or(lhs, rhs) =>
367369
val lhsSet = splitConjunctivePredicates(lhs).toSet
368370
val rhsSet = splitConjunctivePredicates(rhs).toSet
369-
val commonPredicates = lhsSet & rhsSet
370-
val otherPredicates = (lhsSet | rhsSet) &~ commonPredicates
371-
otherPredicates.reduceOption(Or).getOrElse(Literal(true)) :: commonPredicates.toList
371+
val common = lhsSet.intersect(rhsSet)
372+
373+
(lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And))
374+
.reduceOption(Or)
375+
.map(_ :: common.toList)
376+
.getOrElse(common.toList)
377+
372378
case _ => predicate :: Nil
373379
}
374380
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,15 @@ class NormalizeFiltersSuite extends PlanTest {
5252

5353
test("a && a => a") {
5454
checkExpression('a === 1 && 'a === 1, 'a === 1)
55+
checkExpression('a === 1 && 'a === 1 && 'a === 1, 'a === 1)
5556
}
5657

5758
test("a || a => a") {
5859
checkExpression('a === 1 || 'a === 1, 'a === 1)
60+
checkExpression('a === 1 || 'a === 1 || 'a === 1, 'a === 1)
5961
}
6062

61-
test("(a && b) || (a && c)") {
63+
test("(a && b) || (a && c) => a && (b || c)") {
6264
checkExpression(
6365
('a === 1 && 'a < 10) || ('a > 2 && 'a === 1),
6466
('a === 1) && ('a < 10 || 'a > 2))

sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,20 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
105105

106106
test(query) {
107107
val schemaRdd = sql(query)
108-
assertResult(expectedQueryResult.toArray, "Wrong query result") {
108+
val queryExecution = schemaRdd.queryExecution
109+
110+
assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
109111
schemaRdd.collect().map(_.head).toArray
110112
}
111113

112114
val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect {
113115
case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
114116
}.head
115117

116-
assert(readBatches === expectedReadBatches, "Wrong number of read batches")
117-
assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions")
118+
assert(readBatches === expectedReadBatches, s"Wrong number of read batches: $queryExecution")
119+
assert(
120+
readPartitions === expectedReadPartitions,
121+
s"Wrong number of read partitions: $queryExecution")
118122
}
119123
}
120124
}

0 commit comments

Comments
 (0)