diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 18feb98519fbe..c69ce4422f54f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.hive
 import java.io.IOException
 import java.util.Locale
 
+import scala.collection.{immutable, mutable}
+
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.sql._
@@ -256,16 +258,76 @@ private[hive] trait HiveStrategies {
         // Filter out all predicates that only deal with partition keys, these are given to the
         // hive table scan operator to be used for partition pruning.
         val partitionKeyIds = AttributeSet(relation.partitionCols)
-        val (pruningPredicates, otherPredicates) = predicates.partition { predicate =>
-          !predicate.references.isEmpty &&
-          predicate.references.subsetOf(partitionKeyIds)
+
+        case class PruningResult(pruning: Option[Expression], allChildrenCanPruning: Boolean)
+
+        // An Or of two Ands can still be used for partition pruning even when each And
+        // mixes partition and non-partition columns. Assume p1/p2 are partition columns
+        // and c1 is not: from "(p1 = 'a' and c1 = 1) or (p2 = 'b' and c1 = 3)" we can
+        // extract "(p1 = 'a') or (p2 = 'b')" to prune partitions. To keep the result
+        // correct, the original predicate must also stay in the outer filter: the
+        // extracted predicate is implied by the original one, so pruning with it never
+        // removes a partition the outer filter still needs.
+        def canPruning(predicate: Expression): PruningResult = {
+          predicate match {
+            case And(left, right) =>
+              val leftResult = canPruning(left)
+              val rightResult = canPruning(right)
+
+              if (leftResult.pruning.isDefined && rightResult.pruning.isDefined) {
+                val pruning = Option(And(leftResult.pruning.get, rightResult.pruning.get))
+                val allChildrenCanPruning =
+                  leftResult.allChildrenCanPruning && rightResult.allChildrenCanPruning
+                PruningResult(pruning, allChildrenCanPruning)
+              } else if (leftResult.pruning.isDefined) {
+                // Only one side prunes: it alone is a weaker but still sound predicate.
+                PruningResult(leftResult.pruning, false)
+              } else if (rightResult.pruning.isDefined) {
+                PruningResult(rightResult.pruning, false)
+              } else PruningResult(Option.empty, false)
+
+            case Or(left, right) =>
+              val leftResult = canPruning(left)
+              val rightResult = canPruning(right)
+
+              // An Or can prune only when both sides contribute a pruning predicate.
+              val pruning: Option[Expression] =
+                if (leftResult.pruning.isDefined && rightResult.pruning.isDefined) {
+                  Option(Or(leftResult.pruning.get, rightResult.pruning.get))
+                } else Option.empty
+
+              val allChildrenCanPruning =
+                leftResult.allChildrenCanPruning && rightResult.allChildrenCanPruning
+
+              PruningResult(pruning, allChildrenCanPruning)
+
+            case _ => if (!predicate.references.isEmpty &&
+                predicate.references.subsetOf(partitionKeyIds)) {
+              PruningResult(Option(predicate), true)
+            } else PruningResult(Option.empty, false)
+          }
+        }
+
+        val pruningPredicates = mutable.ListBuffer[Expression]()
+        val otherPredicates = mutable.ListBuffer[Expression]()
+        predicates.foreach { predicate =>
+          // Extract the pruning predicate from this predicate, if any.
+          val result = canPruning(predicate)
+          if (result.pruning.isDefined) {
+            pruningPredicates += result.pruning.get
+          }
+
+          // If any child cannot prune, also keep the original predicate in the filter.
+          if (!result.allChildrenCanPruning) {
+            otherPredicates += predicate
+          }
         }
 
         pruneFilterProject(
           projectList,
-          otherPredicates,
+          otherPredicates.to[immutable.Seq],
           identity[Seq[Expression]],
-          HiveTableScanExec(_, relation, pruningPredicates)(sparkSession)) :: Nil
+          HiveTableScanExec(_, relation, pruningPredicates.to[immutable.Seq])(sparkSession)) :: Nil
       case _ =>
         Nil
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 3f9bb8de42e09..d41d5c099c033 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -18,11 +18,15 @@
 package org.apache.spark.sql.hive.execution
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal, Or}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton}
 import org.apache.spark.sql.hive.test.TestHive._
 import org.apache.spark.sql.hive.test.TestHive.implicits._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{IntegerType, StringType}
 import org.apache.spark.util.Utils
 
 class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestHiveSingleton {
@@ -173,6 +177,109 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH
     }
   }
 
+  test("SPARK-28983: HiveTableScans OR push down to partitioned table 1") {
+    val table = "table_with_partition"
+    val testSql =
+      s"""
+         |SELECT *
+         |FROM $table
+         |WHERE
+         |(p1='a')
+         |OR
+         |(p1='c')
+      """.stripMargin
+    planTest(table, testSql) { plan =>
+      assert(plan.isInstanceOf[HiveTableScanExec])
+      val scan = checkType(plan, classOf[HiveTableScanExec])
+      assert(scan.partitionPruningPred.size == 1)
+      compareExpression(scan.partitionPruningPred(0), Or(
+        EqualTo(AttributeReference("p1", StringType, true)(), Literal("a")),
+        EqualTo(AttributeReference("p1", StringType, true)(), Literal("c"))
+      ))
+    }
+  }
+
+  test("SPARK-28983: HiveTableScans OR push down to partitioned table 2") {
+    val table = "table_with_partition"
+    val testSql =
+      s"""
+         |SELECT *
+         |FROM $table
+         |WHERE
+         |(p1='a' and id=1)
+         |OR
+         |(p1='c')
+      """.stripMargin
+    planTest(table, testSql) { plan =>
+      val filter = checkType(plan, classOf[FilterExec])
+      val scan = checkType(filter.child, classOf[HiveTableScanExec])
+
+      compareExpression(filter.condition, Or(
+        and(
+          EqualTo(AttributeReference("p1", StringType, true)(), Literal("a")),
+          EqualTo(AttributeReference("id", IntegerType, true)(), Literal(1))
+        ),
+        EqualTo(AttributeReference("p1", StringType, true)(), Literal("c"))
+      ))
+
+      assert(scan.partitionPruningPred.size == 1)
+
+      compareExpression(scan.partitionPruningPred(0), Or(
+        EqualTo(AttributeReference("p1", StringType, true)(), Literal("a")),
+        EqualTo(AttributeReference("p1", StringType, true)(), Literal("c"))
+      ))
+    }
+  }
+
+  test("SPARK-28983: HiveTableScans OR push down to partitioned table 3") {
+    val table = "table_with_partition"
+    val testSql =
+      s"""
+         |SELECT *
+         |FROM $table
+         |WHERE
+         |(p1='a' or p1='d')
+         |OR
+         |(p1='c')
+      """.stripMargin
+    planTest(table, testSql) { plan =>
+      val scan = checkType(plan, classOf[HiveTableScanExec])
+
+      assert(scan.partitionPruningPred.size == 1)
+      compareExpression(scan.partitionPruningPred(0), Or(
+        Or(
+          EqualTo(AttributeReference("p1", StringType, true)(), Literal("a")),
+          EqualTo(AttributeReference("p1", StringType, true)(), Literal("d"))
+        ),
+        EqualTo(AttributeReference("p1", StringType, true)(), Literal("c"))
+      ))
+    }
+  }
+
+  test("SPARK-28983: HiveTableScans OR push down to partitioned table 4") {
+    val table = "table_with_partition"
+    val testSql =
+      s"""
+         |SELECT *
+         |FROM $table
+         |WHERE
+         |(p1='a')
+         |OR
+         |(id=5)
+      """.stripMargin
+    planTest(table, testSql) { plan =>
+      val filter = checkType(plan, classOf[FilterExec])
+      val scan = checkType(filter.child, classOf[HiveTableScanExec])
+
+      compareExpression(filter.condition, Or(
+        EqualTo(AttributeReference("p1", StringType, true)(), Literal("a")),
+        EqualTo(AttributeReference("id", IntegerType, true)(), Literal(5))
+      ))
+
+      assert(scan.partitionPruningPred.size == 0)
+    }
+  }
+
   test("HiveTableScanExec canonicalization for different orders of partition filters") {
     val table = "hive_tbl_part"
     withTable(table) {
@@ -187,9 +294,54 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH
     }
   }
 
+  private def planTest(table: String, testSql: String)
+      (func: SparkPlan => Unit): Unit = {
+    val view = "src"
+    withTempView(view) {
+      spark.range(1, 5).createOrReplaceTempView(view)
+      withTable(table) {
+        sql(
+          s"""
+             |CREATE TABLE $table(id int)
+             |PARTITIONED BY (p1 string, p2 string)
+          """.stripMargin)
+
+        sql(
+          s"""
+             |FROM $view v
+             |INSERT INTO TABLE $table
+             |PARTITION(p1='a', p2='b')
+             |SELECT v.id
+             |INSERT INTO TABLE $table
+             |PARTITION(p1='c', p2='d')
+             |SELECT v.id
+          """.stripMargin)
+
+        val plan = sql(testSql).queryExecution.sparkPlan
+        func(plan)
+      }
+    }
+  }
+
   private def getHiveTableScanExec(query: String): HiveTableScanExec = {
     sql(query).queryExecution.sparkPlan.collectFirst {
       case p: HiveTableScanExec => p
     }.get
   }
+
+  private class ExpressionComparator extends PlanTest {
+    def compare(e1: Expression, e2: Expression): Unit =
+      super.compareExpressions(e1, e2)
+  }
+
+  private def compareExpression(e1: Expression, e2: Expression): Unit =
+    new ExpressionComparator().compare(e1, e2)
+
+  private def checkType[T](any: Any, clazz: Class[T]): T = {
+    assert(clazz.isInstance(any), s"$any is not an instance of $clazz")
+    clazz.cast(any)
+  }
+
+  private def and(left: Expression, right: Expression) =
+    org.apache.spark.sql.catalyst.expressions.And(left, right)
 }
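
A minimal standalone sketch of the extraction rule implemented above, using a toy expression tree instead of Catalyst expressions; all names in it (Expr, Leaf, AndE, OrE, Extracted, extract) are invented for illustration and are not part of the patch:

// Toy expression ADT standing in for Catalyst's And/Or/leaf predicates.
sealed trait Expr
case class Leaf(column: String, value: String) extends Expr
case class AndE(left: Expr, right: Expr) extends Expr
case class OrE(left: Expr, right: Expr) extends Expr

// Same shape as PruningResult in the patch: the extracted pruning predicate
// (if any) plus a flag saying whether every leaf referenced only partition columns.
case class Extracted(pruning: Option[Expr], complete: Boolean)

object PruningSketch {
  // partitionCols plays the role of partitionKeyIds in the patch.
  def extract(e: Expr, partitionCols: Set[String]): Extracted = e match {
    case AndE(l, r) =>
      val (le, re) = (extract(l, partitionCols), extract(r, partitionCols))
      (le.pruning, re.pruning) match {
        // Both sides prune: keep the conjunction.
        case (Some(lp), Some(rp)) => Extracted(Some(AndE(lp, rp)), le.complete && re.complete)
        // One side prunes: it alone is a weaker but still sound pruning predicate.
        case (Some(lp), None) => Extracted(Some(lp), complete = false)
        case (None, Some(rp)) => Extracted(Some(rp), complete = false)
        case _ => Extracted(None, complete = false)
      }
    case OrE(l, r) =>
      // An Or is usable only when both sides yield a pruning predicate.
      val (le, re) = (extract(l, partitionCols), extract(r, partitionCols))
      val pruning = for (lp <- le.pruning; rp <- re.pruning) yield OrE(lp, rp)
      Extracted(pruning, le.complete && re.complete)
    case leaf @ Leaf(col, _) =>
      if (partitionCols.contains(col)) Extracted(Some(leaf), complete = true)
      else Extracted(None, complete = false)
  }

  def main(args: Array[String]): Unit = {
    // The example from the comment: (p1='a' and c1=1) or (p2='b' and c1=3).
    val pred = OrE(
      AndE(Leaf("p1", "a"), Leaf("c1", "1")),
      AndE(Leaf("p2", "b"), Leaf("c1", "3")))
    println(extract(pred, Set("p1", "p2")))
  }
}

Running it prints Extracted(Some(OrE(Leaf(p1,a),Leaf(p2,b))),false): the weaker predicate "(p1='a') or (p2='b')" is safe to prune partitions with, and complete == false signals that the original predicate must stay in the outer Filter, which is what the tests above assert.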