diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 4ab0599e4477..890d7810507b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException
 import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType}
+import org.apache.spark.sql.types.{AtomicType, ByteType, DataType, IntegerType, IntegralType, LongType, ShortType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -646,16 +646,16 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
   }
 
   object ExtractableLiteral {
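+    // Each extracted literal is now returned together with its Catalyst DataType so
+    // that convert() below can check that the attribute and literal types are
+    // compatible before the predicate is pushed down to the Hive metastore.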
-    def unapply(expr: Expression): Option[String] = expr match {
+    def unapply(expr: Expression): Option[(String, DataType)] = expr match {
       case Literal(null, _) => None // `null`s can be cast as other types; we want to avoid NPEs.
-      case Literal(value, _: IntegralType) => Some(value.toString)
-      case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString))
+      case l @ Literal(value, _: IntegralType) => Some((value.toString, l.dataType))
+      case Literal(value, _: StringType) => Some((quoteStringLiteral(value.toString), StringType))
       case _ => None
     }
   }
 
   object ExtractableLiterals {
-    def unapply(exprs: Seq[Expression]): Option[Seq[String]] = {
+    def unapply(exprs: Seq[Expression]): Option[Seq[(String, DataType)]] = {
       // SPARK-24879: The Hive metastore filter parser does not support "null", but we still want
       // to push down as many predicates as we can while still maintaining correctness.
       // In SQL, the `IN` expression evaluates as follows:
@@ -682,15 +682,15 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
   }
 
   object ExtractableValues {
-    private lazy val valueToLiteralString: PartialFunction[Any, String] = {
-      case value: Byte => value.toString
-      case value: Short => value.toString
-      case value: Int => value.toString
-      case value: Long => value.toString
-      case value: UTF8String => quoteStringLiteral(value.toString)
+    private lazy val valueToLiteralString: PartialFunction[Any, (String, DataType)] = {
+      case value: Byte => (value.toString, ByteType)
+      case value: Short => (value.toString, ShortType)
+      case value: Int => (value.toString, IntegerType)
+      case value: Long => (value.toString, LongType)
+      case value: UTF8String => (quoteStringLiteral(value.toString), StringType)
     }
 
-    def unapply(values: Set[Any]): Option[Seq[String]] = {
+    def unapply(values: Set[Any]): Option[Seq[(String, DataType)]] = {
       val extractables = values.toSeq.map(valueToLiteralString.lift)
       if (extractables.nonEmpty && extractables.forall(_.isDefined)) {
         Some(extractables.map(_.get))
@@ -726,9 +726,9 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled
 
     object ExtractAttribute {
-      def unapply(expr: Expression): Option[Attribute] = {
+      def unapply(expr: Expression): Option[(Attribute, DataType)] = {
         expr match {
-          case attr: Attribute => Some(attr)
+          case attr: Attribute => Some((attr, attr.dataType))
           case Cast(child @ AtomicType(), dt: AtomicType, _)
               if Cast.canUpCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child)
           case _ => None
@@ -736,23 +736,69 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
       }
     }
 
+    def compatibleTypes(dt1: Any, dt2: Any): Boolean =
+      (dt1, dt2) match {
+        case (_: IntegralType, _: IntegralType) => true
+        case (_: StringType, _: StringType) => true
+        case _ => false
+      }
+
+    def compatibleTypesIn(dt1: Any, dts: Seq[Any]): Boolean = {
+      dts.forall(compatibleTypes(dt1, _))
+    }
+
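+    // Comparing a non-numeric string literal against an integral partition column can
+    // make the metastore throw a MetaException. fixValue strips the quotes added by
+    // quoteStringLiteral and tries to parse the value in the column's integral type,
+    // returning None (i.e. no pushdown) when the conversion fails.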
+    def fixValue(quotedValue: String, desiredType: DataType): Option[Any] = try {
+      val value = quotedValue.init.tail // remove leading and trailing quotes
+      desiredType match {
+        case LongType =>
+          Some(value.toLong)
+        case IntegerType =>
+          Some(value.toInt)
+        case ShortType =>
+          Some(value.toShort)
+        case ByteType =>
+          Some(value.toByte)
+        case _ =>
+          None
+      }
+    } catch {
+      case _: NumberFormatException =>
+        None
+    }
+
     def convert(expr: Expression): Option[String] = expr match {
-      case In(ExtractAttribute(SupportedAttribute(name)), ExtractableLiterals(values))
-          if useAdvanced =>
+      case In(ExtractAttribute(SupportedAttribute(name), dt1), ExtractableLiterals(valsAndDts))
+          if useAdvanced && compatibleTypesIn(dt1, valsAndDts.map(_._2)) =>
+        val values = valsAndDts.map(_._1)
         Some(convertInToOr(name, values))
 
-      case InSet(ExtractAttribute(SupportedAttribute(name)), ExtractableValues(values))
-          if useAdvanced =>
+      case InSet(ExtractAttribute(SupportedAttribute(name), dt1), ExtractableValues(valsAndDts))
+          if useAdvanced && compatibleTypesIn(dt1, valsAndDts.map(_._2)) =>
+        val values = valsAndDts.map(_._1)
         Some(convertInToOr(name, values))
 
       case op @ SpecialBinaryComparison(
-          ExtractAttribute(SupportedAttribute(name)), ExtractableLiteral(value)) =>
+          ExtractAttribute(SupportedAttribute(name), dt1), ExtractableLiteral(value, dt2))
+          if compatibleTypes(dt1, dt2) =>
         Some(s"$name ${op.symbol} $value")
 
       case op @ SpecialBinaryComparison(
-          ExtractableLiteral(value), ExtractAttribute(SupportedAttribute(name))) =>
+          ExtractAttribute(SupportedAttribute(name), dt1), ExtractableLiteral(rawValue, dt2))
+          if dt1.isInstanceOf[IntegralType] && dt2.isInstanceOf[StringType] =>
+        fixValue(rawValue, dt1).map { value =>
+          s"$name ${op.symbol} $value"
+        }
+
+      case op @ SpecialBinaryComparison(
+          ExtractableLiteral(value, dt2), ExtractAttribute(SupportedAttribute(name), dt1))
+          if compatibleTypes(dt1, dt2) =>
         Some(s"$value ${op.symbol} $name")
 
+      case op @ SpecialBinaryComparison(
+          ExtractableLiteral(rawValue, dt2), ExtractAttribute(SupportedAttribute(name), dt1))
+          if dt1.isInstanceOf[IntegralType] && dt2.isInstanceOf[StringType] =>
+        fixValue(rawValue, dt1).map { value =>
+          s"$value ${op.symbol} $name"
+        }
+
       case And(expr1, expr2) if useAdvanced =>
         val converted = convert(expr1) ++ convert(expr2)
         if (converted.isEmpty) {
@@ -795,6 +841,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
 
     val partitions =
       if (filter.isEmpty) {
+        logDebug("Falling back to getting all partitions")
         getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
       } else {
         logDebug(s"Hive metastore filter is '$filter'.")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
index 2a4efd0cce6e..567510216c50 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
@@ -60,9 +60,49 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
     "1 = intcol")
 
   filterTest("int and string filter",
-    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil,
+    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", StringType)) :: Nil,
     "1 = intcol and \"a\" = strcol")
 
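+  // In the new cases below, an empty expected filter string means the predicate must
+  // not be pushed down to the metastore at all.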
+  filterTest("int and int column/string literal filter",
+    (a("intcol1", IntegerType) === Literal(1)) ::
+      (a("intcol2", IntegerType) === Literal("a")) :: Nil,
+    "intcol1 = 1")
+
+  filterTest("int and int column/string literal filter with conversion",
+    (a("intcol1", IntegerType) === Literal(1)) ::
+      (a("intcol2", IntegerType) === Literal("00002")) :: Nil,
+    "intcol1 = 1 and intcol2 = 2")
+
+  filterTest("int and int column/string literal filter backwards",
+    (Literal(1) === a("intcol1", IntegerType)) ::
+      (Literal("a") === a("intcol2", IntegerType)) :: Nil,
+    "1 = intcol1")
+
+  filterTest("int and int column/string literal filter backwards with conversion",
+    (Literal(1) === a("intcol1", IntegerType)) ::
+      (Literal("00002") === a("intcol2", IntegerType)) :: Nil,
+    "1 = intcol1 and 2 = intcol2")
+
+  filterTest("int filter with in",
+    (a("intcol", IntegerType) in (Literal(1), Literal(2))) :: Nil,
+    "(intcol = 1 or intcol = 2)")
+
+  filterTest("int/string filter with in",
+    (a("intcol", IntegerType) in (Literal("1"), Literal("2"))) :: Nil,
+    "")
+
+  filterTest("int filter with inset",
+    (a("intcol", IntegerType) in ((0 to 11).map(Literal(_)): _*)) :: Nil,
+    "(" + (0 to 11).map(x => s"intcol = $x").mkString(" or ") + ")")
+
+  filterTest("int/string filter with inset",
+    (a("intcol", IntegerType) in ((0 to 11).map(x => Literal(x.toString)): _*)) :: Nil,
+    "")
+
+  filterTest("string filter with in",
+    (a("strcol", StringType) in (Literal("1"), Literal("2"))) :: Nil,
+    "(strcol = \"1\" or strcol = \"2\")")
+
   filterTest("skip varchar",
     (Literal("") === a("varchar", StringType)) :: Nil,
     "")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
index 06aea084330f..fa03c2b48230 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
@@ -80,6 +80,19 @@ class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {
     }
   }
 
+  test("SPARK-33098: Don't push partition filter with mismatched datatypes to metastore") {
+    withTable("t") {
+      sql("create table t (a int) partitioned by (b int) stored as parquet")
+
+      // There are only two test cases because TestHiveSparkSession sets
+      // hive.metastore.integral.jdo.pushdown=true, which, as a side effect, prevents
+      // the MetaException for most of the problem cases. Only the two cases below
+      // would still throw a MetaException when Hive is configured this way.
+      sql("select cast(b as string) b from t").filter("b > '1'").collect()
+      sql("select * from t where cast(b as string) > '1'").collect()
+    }
+  }
+
   override def getScanExecPartitionSize(plan: SparkPlan): Long = {
     plan.collectFirst {
       case p: HiveTableScanExec => p