--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException
 import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType}
+import org.apache.spark.sql.types.{AtomicType, ByteType, DataType, IntegerType, IntegralType, LongType, ShortType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils

@@ -646,16 +646,16 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
   }
 
   object ExtractableLiteral {
-    def unapply(expr: Expression): Option[String] = expr match {
+    def unapply(expr: Expression): Option[(String, DataType)] = expr match {
       case Literal(null, _) => None // `null`s can be cast as other types; we want to avoid NPEs.
-      case Literal(value, _: IntegralType) => Some(value.toString)
-      case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString))
+      case l @ Literal(value, _: IntegralType) => Some((value.toString, l.dataType))
+      case Literal(value, _: StringType) => Some((quoteStringLiteral(value.toString), StringType))
       case _ => None
     }
   }
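A minimal sketch of the new extractor contract (hypothetical REPL check, not part of the diff; `ExtractableLiteral` is local to `convertFilters`, so this only illustrates the shape): the match now yields the rendered literal together with its Catalyst type, so call sites can test type compatibility before building a metastore filter.

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.IntegerType

Literal(42, IntegerType) match {
  case ExtractableLiteral(value, dt) => Some((value, dt)) // Some(("42", IntegerType))
  case _ => None
}
```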

object ExtractableLiterals {
def unapply(exprs: Seq[Expression]): Option[Seq[String]] = {
def unapply(exprs: Seq[Expression]): Option[Seq[(String, DataType)]] = {
// SPARK-24879: The Hive metastore filter parser does not support "null", but we still want
// to push down as many predicates as we can while still maintaining correctness.
// In SQL, the `IN` expression evaluates as follows:
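The comment above is cut off by the fold; for reference, the standard SQL three-valued logic for `IN` that it alludes to (illustrative sketch, assuming a running `spark` session):

```scala
spark.sql("SELECT 1 IN (1, NULL)").collect()   // [true]: a match is found
spark.sql("SELECT 2 IN (1, NULL)").collect()   // [null]: unknown, so a WHERE clause filters the row out
spark.sql("SELECT NULL IN (1, 2)").collect()   // [null]: comparisons with NULL are unknown
```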
@@ -682,15 +682,15 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
   }
 
   object ExtractableValues {
-    private lazy val valueToLiteralString: PartialFunction[Any, String] = {
-      case value: Byte => value.toString
-      case value: Short => value.toString
-      case value: Int => value.toString
-      case value: Long => value.toString
-      case value: UTF8String => quoteStringLiteral(value.toString)
+    private lazy val valueToLiteralString: PartialFunction[Any, (String, DataType)] = {
+      case value: Byte => (value.toString, ByteType)
+      case value: Short => (value.toString, ShortType)
+      case value: Int => (value.toString, IntegerType)
+      case value: Long => (value.toString, LongType)
+      case value: UTF8String => (quoteStringLiteral(value.toString), StringType)
     }
 
-    def unapply(values: Set[Any]): Option[Seq[String]] = {
+    def unapply(values: Set[Any]): Option[Seq[(String, DataType)]] = {
       val extractables = values.toSeq.map(valueToLiteralString.lift)
       if (extractables.nonEmpty && extractables.forall(_.isDefined)) {
         Some(extractables.map(_.get))
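Unlike `In`, `InSet` holds raw runtime values rather than `Literal` expressions, so the partial function above keys on JVM classes and must reconstruct the Catalyst type by hand. Conceptually (hypothetical check; `valueToLiteralString` is private, so this is a sketch of behavior, not a usable call):

```scala
import org.apache.spark.sql.types.{IntegerType, LongType}

Set[Any](1, 2L).toSeq.map(valueToLiteralString.lift)
// Seq(Some(("1", IntegerType)), Some(("2", LongType)))
```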
@@ -726,33 +726,79 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled
 
     object ExtractAttribute {
-      def unapply(expr: Expression): Option[Attribute] = {
+      def unapply(expr: Expression): Option[(Attribute, DataType)] = {
         expr match {
-          case attr: Attribute => Some(attr)
+          case attr: Attribute => Some((attr, attr.dataType))
           case Cast(child @ AtomicType(), dt: AtomicType, _)
               if Cast.canUpCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child)
           case _ => None
         }
       }
     }
 
+    def compatibleTypes(dt1: Any, dt2: Any): Boolean =
+      (dt1, dt2) match {
+        case (_: IntegralType, _: IntegralType) => true
+        case (_: StringType, _: StringType) => true
+        case _ => false
+      }
+
+    def compatibleTypesIn(dt1: Any, dts: Seq[Any]): Boolean = {
+      dts.forall(compatibleTypes(dt1, _))
+    }
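Illustrative behavior of the new helpers (hypothetical REPL checks, not part of the diff): any two integral types are compatible, any two string types are compatible, and mixed pairs fall through to the correction cases below.

```scala
import org.apache.spark.sql.types.{IntegerType, LongType, StringType}

compatibleTypes(IntegerType, LongType)                      // true: both integral
compatibleTypes(StringType, StringType)                     // true
compatibleTypes(IntegerType, StringType)                    // false: mixed types
compatibleTypesIn(IntegerType, Seq(IntegerType, LongType))  // true: all integral
```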

+    def fixValue(quotedValue: String, desiredType: DataType): Option[Any] = try {
+      val value = quotedValue.init.tail // remove leading and trailing quotes
+      desiredType match {
+        case LongType =>
+          Some(value.toLong)
+        case IntegerType =>
+          Some(value.toInt)
+        case ShortType =>
+          Some(value.toShort)
+        case ByteType =>
+          Some(value.toByte)
+      }
+    } catch {
+      case _: NumberFormatException =>
+        None
+    }
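A sketch of the helper's behavior (hypothetical inputs; it assumes the literal arrives quoted, as produced by `quoteStringLiteral`):

```scala
fixValue("\"00002\"", IntegerType) // Some(2): quotes stripped, text parsed as Int
fixValue("\"abc\"", IntegerType)   // None: NumberFormatException is swallowed
```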

     def convert(expr: Expression): Option[String] = expr match {
-      case In(ExtractAttribute(SupportedAttribute(name)), ExtractableLiterals(values))
-          if useAdvanced =>
+      case In(ExtractAttribute(SupportedAttribute(name), dt1), ExtractableLiterals(valsAndDts))
+          if useAdvanced && compatibleTypesIn(dt1, valsAndDts.map(_._2)) =>
+        val values = valsAndDts.map(_._1)
         Some(convertInToOr(name, values))
 
-      case InSet(ExtractAttribute(SupportedAttribute(name)), ExtractableValues(values))
-          if useAdvanced =>
+      case InSet(ExtractAttribute(SupportedAttribute(name), dt1), ExtractableValues(valsAndDts))
+          if useAdvanced && compatibleTypesIn(dt1, valsAndDts.map(_._2)) =>
+        val values = valsAndDts.map(_._1)
         Some(convertInToOr(name, values))
 
       case op @ SpecialBinaryComparison(
-          ExtractAttribute(SupportedAttribute(name)), ExtractableLiteral(value)) =>
+          ExtractAttribute(SupportedAttribute(name), dt1), ExtractableLiteral(value, dt2))
+          if compatibleTypes(dt1, dt2) =>
         Some(s"$name ${op.symbol} $value")
 
       case op @ SpecialBinaryComparison(
-          ExtractableLiteral(value), ExtractAttribute(SupportedAttribute(name))) =>
-        Some(s"$value ${op.symbol} $name")
+          ExtractAttribute(SupportedAttribute(name), dt1), ExtractableLiteral(rawValue, dt2))
+          if dt1.isInstanceOf[IntegralType] && dt2.isInstanceOf[StringType] =>
+        fixValue(rawValue, dt1).map { value =>
+          s"$name ${op.symbol} $value"
+        }
> **Member:** hmm, will this change semantics? suppose we have `cast(b as string) < '012'` where `b` is 11. Before the conversion this will evaluate to false, but after it will evaluate to true.
>
> **Contributor (Author):** Yes, it should probably ignore any literal strings with leading zeros.
>
> **Contributor (Author):**
> > perhaps we should do it in UnwrapCastInBinaryComparison so that it can not only be used by Hive but also other data sources.
>
> Whatever makes sense. There is some (long-time) ongoing work with TypeCoercion (#22038) that fixes a few of these cases. But if that goes through and we can close the gap with the others, that would be fine. I am probably not in a position to provide much help in the optimizer code (at this point).
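A minimal illustration of the concern in this thread (hypothetical values): with `b = 11`, the original string comparison and the rewritten integral comparison disagree whenever the literal has a leading zero.

```scala
val b = 11
b.toString < "012" // false: lexicographic comparison, '1' > '0'
b < "012".toInt    // true: the fixValue-style rewrite compares 11 < 12
```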
> **Contributor (Author):** I don't have an equivalent "attempt to correct" for `In` and `InSet`, just for binary comparisons. In the case of `In` and `InSet`, if the datatypes are not compatible, I just drop the filter (which is what would have happened before SPARK-22384).

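A sketch of that drop behavior (hypothetical session; assumes a Hive table `t` partitioned by an int column `intcol`): a mixed-type `IN` yields no metastore filter, so all partitions are fetched and Spark evaluates the predicate itself.

```scala
// '1' and '2' are strings but intcol is integral, so convert(...)
// returns None and nothing is pushed to the metastore.
spark.sql("SELECT * FROM t WHERE intcol IN ('1', '2')").collect()
```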

+      case op @ SpecialBinaryComparison(
+          ExtractableLiteral(value, dt2), ExtractAttribute(SupportedAttribute(name), dt1))
+          if compatibleTypes(dt1, dt2) =>
+        Some(s"$value ${op.symbol} $name")
+
+      case op @ SpecialBinaryComparison(
+          ExtractableLiteral(rawValue, dt2), ExtractAttribute(SupportedAttribute(name), dt1))
+          if dt1.isInstanceOf[IntegralType] && dt2.isInstanceOf[StringType] =>
+        fixValue(rawValue, dt1).map { value =>
+          s"$value ${op.symbol} $name"
+        }

       case And(expr1, expr2) if useAdvanced =>
         val converted = convert(expr1) ++ convert(expr2)
         if (converted.isEmpty) {
@@ -795,6 +841,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
 
     val partitions =
       if (filter.isEmpty) {
+        logDebug(s"Falling back to getting all partitions")
         getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
       } else {
         logDebug(s"Hive metastore filter is '$filter'.")
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
@@ -60,9 +60,49 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
     "1 = intcol")
 
   filterTest("int and string filter",
-    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil,
+    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", StringType)) :: Nil,
     "1 = intcol and \"a\" = strcol")

filterTest("int and int column/string literal filter",
(a("intcol1", IntegerType) === Literal(1)) ::
(a("intcol2", IntegerType) === (Literal("a"))) :: Nil,
"intcol1 = 1")

filterTest("int and int column/string literal filter with conversion",
(a("intcol1", IntegerType) === Literal(1)) ::
(a("intcol2", IntegerType) === (Literal("00002"))) :: Nil,
"intcol1 = 1 and intcol2 = 2")

filterTest("int and int column/string literal filter backwards",
(Literal(1) === a("intcol1", IntegerType)) ::
(Literal("a") === a("intcol2", IntegerType)) :: Nil,
"1 = intcol1")

filterTest("int and int column/string literal filter backwards with conversion",
(Literal(1) === a("intcol1", IntegerType)) ::
(Literal("00002") === a("intcol2", IntegerType)) :: Nil,
"1 = intcol1 and 2 = intcol2")

filterTest("int filter with in",
(a("intcol", IntegerType) in (Literal(1), Literal(2))) :: Nil,
"(intcol = 1 or intcol = 2)")

filterTest("int/string filter with in",
(a("intcol", IntegerType) in (Literal("1"), Literal("2"))) :: Nil,
"")

filterTest("int filter with inset",
(a("intcol", IntegerType) in ((0 to 11).map(Literal(_)): _*)) :: Nil,
"(" + (0 to 11).map(x => s"intcol = $x").mkString(" or ") + ")")

filterTest("int/string filter with inset",
(a("intcol", IntegerType) in ((0 to 11).map(x => Literal(x.toString)): _*)) :: Nil,
"")

filterTest("string filter with in",
(a("strcol", StringType) in (Literal("1"), Literal("2"))) :: Nil,
"(strcol = \"1\" or strcol = \"2\")")

filterTest("skip varchar",
(Literal("") === a("varchar", StringType)) :: Nil,
"")
Expand Down
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
@@ -80,6 +80,19 @@ class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {
     }
   }
 
test("SPARK-33098: Don't push partition filter with mismatched datatypes to metastore") {
withTable("t") {
sql("create table t (a int) partitioned by (b int) stored as parquet")

// There are only two test cases because TestHiveSparkSession sets
// hive.metastore.integral.jdo.pushdown=true, which, as a side effect, prevents
// the Metaexception for most of the problem cases. Only the two cases below
// would still throw a MetaException when Hive is configured this way
sql("select cast(b as string) b from t").filter("b > '1'").collect
sql("select * from t where cast(b as string) > '1'").collect
}
}

   override def getScanExecPartitionSize(plan: SparkPlan): Long = {
     plan.collectFirst {
       case p: HiveTableScanExec => p