diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 8ccb17ce35925..db67480ceb77a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -700,7 +700,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def unapply(values: Set[Any]): Option[Seq[String]] = { - val extractables = values.toSeq.map(valueToLiteralString.lift) + val extractables = values.filter(_ != null).toSeq.map(valueToLiteralString.lift) if (extractables.nonEmpty && extractables.forall(_.isDefined)) { Some(extractables.map(_.get)) } else { @@ -715,7 +715,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def unapply(values: Set[Any]): Option[Seq[String]] = { - val extractables = values.toSeq.map(valueToLiteralString.lift) + val extractables = values.filter(_ != null).toSeq.map(valueToLiteralString.lift) if (extractables.nonEmpty && extractables.forall(_.isDefined)) { Some(extractables.map(_.get)) } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 6962f9dd6b186..79b34bd141de3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -187,5 +187,20 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { } } + test("SPARK-34538: Skip InSet null value during push filter to Hive metastore") { + withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING_INSET_THRESHOLD.key -> "3") { + val intFilter = InSet(a("p", IntegerType), Set(null, 1, 2)) + val intConverted = shim.convertFilters(testTable, Seq(intFilter), conf.sessionLocalTimeZone) + assert(intConverted == "(p = 1 or p = 2)") + } + + withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING_INSET_THRESHOLD.key -> "3") { + val dateFilter = InSet(a("p", DateType), Set(null, + Literal(Date.valueOf("2020-01-01")).eval(), Literal(Date.valueOf("2021-01-01")).eval())) + val dateConverted = shim.convertFilters(testTable, Seq(dateFilter), conf.sessionLocalTimeZone) + assert(dateConverted == "(p = 2020-01-01 or p = 2021-01-01)") + } + } + private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() }