sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.hive
import java.io.IOException
import java.util.Locale

import scala.collection.{immutable, mutable}

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.sql._
@@ -256,16 +258,76 @@ private[hive] trait HiveStrategies {
// Filter out all predicates that only deal with partition keys; these are given to the
// Hive table scan operator to be used for partition pruning.
val partitionKeyIds = AttributeSet(relation.partitionCols)
val (pruningPredicates, otherPredicates) = predicates.partition { predicate =>
!predicate.references.isEmpty &&
predicate.references.subsetOf(partitionKeyIds)
}

// pruning: the partition-only predicate extracted from the input, if any.
// allChildrenCanPruning: true only when the entire input references partition
// columns alone, so the extracted predicate is equivalent to the original and
// the outer filter does not need to re-check it.
case class PruningResult(pruning: Option[Expression], allChildrenCanPruning: Boolean)

// Extract a partition-pruning predicate from Or(And(_, _), And(_, _)) when the
// Ands mix partition and non-partition columns. Assume p1/p2 are partition
// columns and c1 is not: from "(p1 = 'a' and c1 = 1) or (p2 = 'b' and c1 = 3)"
// we can extract "(p1 = 'a') or (p2 = 'b')" to prune partitions. To keep the
// result correct, the original predicate must also go into the outer filter:
// the original predicate implies the extracted one, so pruning with the
// extracted predicate never skips a partition the original could match.
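// For illustration (hypothetical inputs; assumes p1/p2 are partition columns
// and c1 is not):
//   canPruning((p1 = 'a' AND c1 = 1) OR (p2 = 'b' AND c1 = 3))
//     => PruningResult(Some((p1 = 'a') OR (p2 = 'b')), allChildrenCanPruning = false)
//   canPruning((p1 = 'a') OR (p1 = 'c'))
//     => PruningResult(Some((p1 = 'a') OR (p1 = 'c')), allChildrenCanPruning = true)
//   canPruning((p1 = 'a') OR (c1 = 1))
//     => PruningResult(None, false), because one Or branch has no partition predicate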
def canPruning(predicate: Expression): PruningResult = {
predicate match {
// Either side of an And can narrow the partition set on its own.
case And(left, right) =>
val leftResult = canPruning(left)
val rightResult = canPruning(right)

if (leftResult.pruning.isDefined && rightResult.pruning.isDefined) {
val pruning = Some(And(leftResult.pruning.get, rightResult.pruning.get))
val allChildrenCanPruning =
leftResult.allChildrenCanPruning && rightResult.allChildrenCanPruning
PruningResult(pruning, allChildrenCanPruning)
} else if (leftResult.pruning.isDefined) {
PruningResult(leftResult.pruning, false)
} else if (rightResult.pruning.isDefined) {
PruningResult(rightResult.pruning, false)
} else {
PruningResult(None, false)
}

// An Or can prune only when both branches yield a pruning predicate;
// otherwise partitions selected solely by the other branch could be lost.
case Or(left, right) =>
val leftResult = canPruning(left)
val rightResult = canPruning(right)

val pruning: Option[Expression] =
if (leftResult.pruning.isDefined && rightResult.pruning.isDefined) {
Some(Or(leftResult.pruning.get, rightResult.pruning.get))
} else None

val allChildrenCanPruning =
leftResult.allChildrenCanPruning && rightResult.allChildrenCanPruning

PruningResult(pruning, allChildrenCanPruning)

case _ =>
if (predicate.references.nonEmpty &&
predicate.references.subsetOf(partitionKeyIds)) {
PruningResult(Some(predicate), true)
} else {
PruningResult(None, false)
}
}
}

val pruningPredicates = mutable.ListBuffer[Expression]()
val otherPredicates = mutable.ListBuffer[Expression]()
predicates.foreach { predicate =>
// Extract the partition-pruning part of this predicate, if any.
val result = canPruning(predicate)
if (result.pruning.isDefined) {
pruningPredicates += result.pruning.get
}

// If any part could not be pushed into pruning, the original predicate
// must still be evaluated by the outer filter.
if (!result.allChildrenCanPruning) {
otherPredicates += predicate
}
}
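// Invariant: every expression in pruningPredicates is implied by one of the
// original predicates, so pruning with it can only skip partitions that
// contain no matching rows; any predicate that was weakened during extraction
// stays in otherPredicates and is re-evaluated by the Filter above the scan.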

pruneFilterProject(
projectList,
otherPredicates,
otherPredicates.to[immutable.Seq],
identity[Seq[Expression]],
HiveTableScanExec(_, relation, pruningPredicates)(sparkSession)) :: Nil
HiveTableScanExec(_, relation, pruningPredicates.to[immutable.Seq])(sparkSession)) :: Nil
case _ =>
Nil
}
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -18,11 +18,15 @@
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, Literal, Or}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton}
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.util.Utils

class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestHiveSingleton {
@@ -173,6 +177,109 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestHiveSingleton {
}
}

test("SPARK-28983: HiveTableScans OR push down to partitioned table 1") {
val table = "table_with_partition"
val testSql =
s"""
|SELECT *
|FROM $table
|WHERE
|(p1='a')
|OR
|(p1='c')
""".stripMargin
planTest(table, testSql) { plan =>
val scan = checkType(plan, classOf[HiveTableScanExec])
assert(scan.partitionPruningPred.size == 1)
compareExpression(scan.partitionPruningPred(0), Or(
EqualTo(AttributeReference("p1", StringType, true)(), Literal("a")),
EqualTo(AttributeReference("p1", StringType, true)(), Literal("c"))
))
}
}

test("SPARK-28983: HiveTableScans OR push down to partitioned table 2") {
val table = "table_with_partition"
val testSql =
s"""
|SELECT *
|FROM $table
|WHERE
|(p1='a' and id=1)
|OR
|(p1='c')
""".stripMargin
planTest(table, testSql) { plan =>
val filter = checkType(plan, classOf[FilterExec])
val scan = checkType(filter.child, classOf[HiveTableScanExec])

compareExpression(filter.condition, Or(
And(
EqualTo(AttributeReference("p1", StringType, true)(), Literal("a")),
EqualTo(AttributeReference("id", IntegerType, true)(), Literal(1))
),
EqualTo(AttributeReference("p1", StringType, true)(), Literal("c"))
))

assert(scan.partitionPruningPred.size == 1)

compareExpression(scan.partitionPruningPred(0), Or(
EqualTo(AttributeReference("p1", StringType, true)(), Literal("a")),
EqualTo(AttributeReference("p1", StringType, true)(), Literal("c"))
))
}
}

test("SPARK-28983: HiveTableScans OR push down to partitioned table 3") {
val table = "table_with_partition"
val testSql =
s"""
|SELECT *
|FROM $table
|WHERE
|(p1='a' or p1='d')
|OR
|(p1='c')
""".stripMargin
planTest(table, testSql) { plan =>
val scan = checkType(plan, classOf[HiveTableScanExec])

assert(scan.partitionPruningPred.size == 1)
compareExpression(scan.partitionPruningPred(0), Or(
Or(
EqualTo(AttributeReference("p1", StringType, true)(), Literal("a")),
EqualTo(AttributeReference("p1", StringType, true)(), Literal("d"))
),
EqualTo(AttributeReference("p1", StringType, true)(), Literal("c"))
))
}
}

test("SPARK-28983: HiveTableScans OR push down to partitioned table 4") {
val table = "table_with_partition"
val testSql =
s"""
|SELECT *
|FROM $table
|WHERE
|(p1='a')
|OR
|(id=5)
""".stripMargin
planTest(table, testSql) { plan =>
val filter = checkType(plan, classOf[FilterExec])
val scan = checkType(filter.child, classOf[HiveTableScanExec])

compareExpression(filter.condition, Or(
EqualTo(AttributeReference("p1", StringType, true)(), Literal("a")),
EqualTo(AttributeReference("id", IntegerType, true)(), Literal(5))
))

assert(scan.partitionPruningPred.size == 0)
}
}

test("HiveTableScanExec canonicalization for different orders of partition filters") {
val table = "hive_tbl_part"
withTable(table) {
@@ -187,9 +294,54 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestHiveSingleton {
}
}

private def planTest(table: String, testSql: String)
(func: SparkPlan => Unit): Unit = {
val view = "src"
withTempView(view) {
spark.range(1, 5).createOrReplaceTempView(view)
withTable(table) {
sql(
s"""
|CREATE TABLE $table(id int)
|PARTITIONED BY (p1 string, p2 string)
""".stripMargin)

sql(
s"""
|FROM $view v
|INSERT INTO TABLE $table
|PARTITION(p1='a', p2='b')
|SELECT v.id
|INSERT INTO TABLE $table
|PARTITION(p1='c', p2='d')
|SELECT v.id
""".stripMargin)

val plan = sql(testSql).queryExecution.sparkPlan
func(plan)
}
}
}

private def getHiveTableScanExec(query: String): HiveTableScanExec = {
sql(query).queryExecution.sparkPlan.collectFirst {
case p: HiveTableScanExec => p
}.get
}

// PlanTest#compareExpressions is protected, so expose it through a small subclass.
private class ExpressionComparator extends PlanTest {
def compare(e1: Expression, e2: Expression): Unit =
super.compareExpressions(e1, e2)
}

private def compareExpression(e1: Expression, e2: Expression): Unit =
new ExpressionComparator().compare(e1, e2)

private def checkType[T](any: Any, clazz: Class[T]): T = {
assert(clazz.isInstance(any), s"$any is not an instance of $clazz")
clazz.cast(any)
}

}