diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index b4a8bafe22df..cc96d905a86f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst
 
 import java.sql.{Date, Timestamp}
+import java.time.LocalDate
 
 import scala.language.implicitConversions
 
@@ -146,6 +147,7 @@ package object dsl {
   implicit def doubleToLiteral(d: Double): Literal = Literal(d)
   implicit def stringToLiteral(s: String): Literal = Literal.create(s, StringType)
   implicit def dateToLiteral(d: Date): Literal = Literal(d)
+  implicit def localDateToLiteral(d: LocalDate): Literal = Literal(d)
   implicit def bigDecimalToLiteral(d: BigDecimal): Literal = Literal(d.underlying())
   implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d)
   implicit def decimalToLiteral(d: Decimal): Literal = Literal(d)
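Not part of the patch: a minimal sketch of what the new Catalyst DSL implicit enables, assuming `org.apache.spark.sql.catalyst.dsl.expressions._` is in scope (the column name `_1` is illustrative).

```scala
import java.time.LocalDate

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal

// With localDateToLiteral in scope, a java.time.LocalDate on either side of
// a comparison is lifted to a Literal automatically, mirroring the existing
// java.sql.Date conversion.
val d = LocalDate.of(2017, 8, 18)
val pred = $"_1" === d          // implicit conversion builds Literal(d)
val same = Literal(d) === $"_1" // equivalent explicit form
```

This is what lets the updated test suites below write `dates(0) === $"_1"` without wrapping every value in `Literal(...)` by hand.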
diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index f5abd30854e0..a01d5a44da71 100644
--- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.orc
 
+import java.time.LocalDate
+
 import org.apache.orc.storage.common.`type`.HiveDecimal
 import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
 import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder
@@ -24,6 +26,7 @@ import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder
 import org.apache.orc.storage.serde2.io.HiveDecimalWritable
 
 import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateToDays, toJavaDate}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._
@@ -162,6 +165,8 @@ private[sql] object OrcFilters extends OrcFiltersBase {
       value.asInstanceOf[Number].doubleValue()
     case _: DecimalType =>
       new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal]))
+    case _: DateType if value.isInstanceOf[LocalDate] =>
+      toJavaDate(localDateToDays(value.asInstanceOf[LocalDate]))
     case _ => value
   }
 
diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index ee5162bced8a..a1c325e7bb87 100644
--- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -299,26 +299,33 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
   }
 
   test("filter pushdown - date") {
-    val dates = Seq("2017-08-18", "2017-08-19", "2017-08-20", "2017-08-21").map { day =>
+    val input = Seq("2017-08-18", "2017-08-19", "2017-08-20", "2017-08-21").map { day =>
       Date.valueOf(day)
     }
-    withOrcDataFrame(dates.map(Tuple1(_))) { implicit df =>
-      checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL)
-
-      checkFilterPredicate($"_1" === dates(0), PredicateLeaf.Operator.EQUALS)
-      checkFilterPredicate($"_1" <=> dates(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS)
-
-      checkFilterPredicate($"_1" < dates(1), PredicateLeaf.Operator.LESS_THAN)
-      checkFilterPredicate($"_1" > dates(2), PredicateLeaf.Operator.LESS_THAN_EQUALS)
-      checkFilterPredicate($"_1" <= dates(0), PredicateLeaf.Operator.LESS_THAN_EQUALS)
-      checkFilterPredicate($"_1" >= dates(3), PredicateLeaf.Operator.LESS_THAN)
-
-      checkFilterPredicate(Literal(dates(0)) === $"_1", PredicateLeaf.Operator.EQUALS)
-      checkFilterPredicate(Literal(dates(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS)
-      checkFilterPredicate(Literal(dates(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN)
-      checkFilterPredicate(Literal(dates(2)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
-      checkFilterPredicate(Literal(dates(0)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
-      checkFilterPredicate(Literal(dates(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN)
+    withOrcFile(input.map(Tuple1(_))) { path =>
+      Seq(false, true).foreach { java8Api =>
+        withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
+          readFile(path) { implicit df =>
+            val dates = input.map(Literal(_))
+            checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL)
+
+            checkFilterPredicate($"_1" === dates(0), PredicateLeaf.Operator.EQUALS)
+            checkFilterPredicate($"_1" <=> dates(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+            checkFilterPredicate($"_1" < dates(1), PredicateLeaf.Operator.LESS_THAN)
+            checkFilterPredicate($"_1" > dates(2), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+            checkFilterPredicate($"_1" <= dates(0), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+            checkFilterPredicate($"_1" >= dates(3), PredicateLeaf.Operator.LESS_THAN)
+
+            checkFilterPredicate(dates(0) === $"_1", PredicateLeaf.Operator.EQUALS)
+            checkFilterPredicate(dates(0) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+            checkFilterPredicate(dates(1) > $"_1", PredicateLeaf.Operator.LESS_THAN)
+            checkFilterPredicate(dates(2) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+            checkFilterPredicate(dates(0) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+            checkFilterPredicate(dates(3) <= $"_1", PredicateLeaf.Operator.LESS_THAN)
+          }
+        }
+      }
     }
   }
 
diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index 675e08915367..445a52cece1c 100644
--- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.orc
 
+import java.time.LocalDate
+
 import org.apache.hadoop.hive.common.`type`.HiveDecimal
 import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument}
 import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder
@@ -24,6 +26,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder
 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable
 
 import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateToDays, toJavaDate}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._
@@ -162,6 +165,8 @@ private[sql] object OrcFilters extends OrcFiltersBase {
       value.asInstanceOf[Number].doubleValue()
     case _: DecimalType =>
       new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal]))
+    case _: DateType if value.isInstanceOf[LocalDate] =>
+      toJavaDate(localDateToDays(value.asInstanceOf[LocalDate]))
     case _ => value
   }
diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index 1baa69e82bb1..815af05beb00 100644
--- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -300,26 +300,33 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
   }
 
   test("filter pushdown - date") {
-    val dates = Seq("2017-08-18", "2017-08-19", "2017-08-20", "2017-08-21").map { day =>
+    val input = Seq("2017-08-18", "2017-08-19", "2017-08-20", "2017-08-21").map { day =>
       Date.valueOf(day)
     }
-    withOrcDataFrame(dates.map(Tuple1(_))) { implicit df =>
-      checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL)
-
-      checkFilterPredicate($"_1" === dates(0), PredicateLeaf.Operator.EQUALS)
-      checkFilterPredicate($"_1" <=> dates(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS)
-
-      checkFilterPredicate($"_1" < dates(1), PredicateLeaf.Operator.LESS_THAN)
-      checkFilterPredicate($"_1" > dates(2), PredicateLeaf.Operator.LESS_THAN_EQUALS)
-      checkFilterPredicate($"_1" <= dates(0), PredicateLeaf.Operator.LESS_THAN_EQUALS)
-      checkFilterPredicate($"_1" >= dates(3), PredicateLeaf.Operator.LESS_THAN)
-
-      checkFilterPredicate(Literal(dates(0)) === $"_1", PredicateLeaf.Operator.EQUALS)
-      checkFilterPredicate(Literal(dates(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS)
-      checkFilterPredicate(Literal(dates(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN)
-      checkFilterPredicate(Literal(dates(2)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
-      checkFilterPredicate(Literal(dates(0)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
-      checkFilterPredicate(Literal(dates(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN)
+    withOrcFile(input.map(Tuple1(_))) { path =>
+      Seq(false, true).foreach { java8Api =>
+        withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
+          readFile(path) { implicit df =>
+            val dates = input.map(Literal(_))
+            checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL)
+
+            checkFilterPredicate($"_1" === dates(0), PredicateLeaf.Operator.EQUALS)
+            checkFilterPredicate($"_1" <=> dates(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+            checkFilterPredicate($"_1" < dates(1), PredicateLeaf.Operator.LESS_THAN)
+            checkFilterPredicate($"_1" > dates(2), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+            checkFilterPredicate($"_1" <= dates(0), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+            checkFilterPredicate($"_1" >= dates(3), PredicateLeaf.Operator.LESS_THAN)
+
+            checkFilterPredicate(dates(0) === $"_1", PredicateLeaf.Operator.EQUALS)
+            checkFilterPredicate(dates(0) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+            checkFilterPredicate(dates(1) > $"_1", PredicateLeaf.Operator.LESS_THAN)
+            checkFilterPredicate(dates(2) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+            checkFilterPredicate(dates(0) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+            checkFilterPredicate(dates(3) <= $"_1", PredicateLeaf.Operator.LESS_THAN)
+          }
+        }
+      }
     }
   }
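For reference, a minimal standalone sketch (not part of the patch) of the normalization the OrcFilters change performs: when `spark.sql.datetime.java8API.enabled` is set, pushed-down date filter values arrive as `java.time.LocalDate`, which the ORC `SearchArgument` builder does not accept, so the new `DateType` case converts them to `java.sql.Date` via `DateTimeUtils`. The helper name `toOrcDateValue` is hypothetical.

```scala
import java.time.LocalDate

import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateToDays, toJavaDate}

// Hypothetical standalone version of the added DateType case: a LocalDate
// (Proleptic Gregorian) becomes days since the epoch, then the java.sql.Date
// that the ORC SearchArgument builder can compare against.
def toOrcDateValue(value: Any): Any = value match {
  case d: LocalDate => toJavaDate(localDateToDays(d))
  case other => other
}

// For modern dates the round trip should agree with java.sql.Date parsing:
assert(toOrcDateValue(LocalDate.of(2017, 8, 18)) == java.sql.Date.valueOf("2017-08-18"))
```

The design mirrors the `DecimalType` branch just above the new case: the Spark-side literal is normalized into the value type the ORC predicate API expects before the predicate leaf is built, which is also why both test suites now exercise the pushdown with `DATETIME_JAVA8API_ENABLED` set to both `false` and `true`.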