diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 095b0aaddb65..cee795d05100 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -679,7 +679,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def unapply(expr: Expression): Option[Attribute] = { expr match { case attr: Attribute => Some(attr) - case Cast(child @ AtomicType(), dt: AtomicType, _) + case Cast(child @ IntegralType(), dt: IntegralType, _) if Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child) case _ => None } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 5bdb13aec012..2568de409f6b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType} +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, StringType} import org.apache.spark.util.Utils // TODO: Refactor this to `HivePartitionFilteringSuite` @@ -258,6 +258,13 @@ class HiveClientSuite(version: String) buildClient(new Configuration(), sharesHadoopClasses = false) } + test("getPartitionsByFilter: chunk in ('ab', 'ba') and ((cast(ds as string)>'20170102')") { + val day = (20170101 to 20170103, 0 to 23, Seq("ab", "ba")) + testMetastorePartitionFiltering( + attr("chunk").in("ab", "ba") && (attr("ds").cast(StringType) > "20170102"), + day :: Nil) + } + private def testMetastorePartitionFiltering( filterExpr: Expression, expectedDs: Seq[Int],