diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index fe3fea5e35b1..26f5bee72092 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} -import java.time.LocalDate +import java.time.{Instant, LocalDate} import scala.language.implicitConversions @@ -152,6 +152,7 @@ package object dsl { implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d) implicit def decimalToLiteral(d: Decimal): Literal = Literal(d) implicit def timestampToLiteral(t: Timestamp): Literal = Literal(t) + implicit def instantToLiteral(i: Instant): Literal = Literal(i) implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a) implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute = diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index a01d5a44da71..b68563956c82 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc -import java.time.LocalDate +import java.time.{Instant, LocalDate} import org.apache.orc.storage.common.`type`.HiveDecimal import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} @@ -26,7 +26,7 @@ import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.orc.storage.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateToDays, toJavaDate} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -167,6 +167,8 @@ private[sql] object OrcFilters extends OrcFiltersBase { new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal])) case _: DateType if value.isInstanceOf[LocalDate] => toJavaDate(localDateToDays(value.asInstanceOf[LocalDate])) + case _: TimestampType if value.isInstanceOf[Instant] => + toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant])) case _ => value } diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index a1c325e7bb87..88b4b243b543 100644 --- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -245,29 +245,41 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } test("filter pushdown - timestamp") { - val timeString = "2015-08-20 14:57:00" - val timestamps = (1 to 4).map { i => - val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 - new Timestamp(milliseconds) - } - withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + val input = Seq( + "1000-01-01 01:02:03", + "1582-10-01 00:11:22", + "1900-01-01 23:59:59", + "2020-05-25 10:11:12").map(Timestamp.valueOf) - checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) - - checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) - - checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(timestamps(0)) <=> $"_1", - PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(timestamps(2)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(0)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + withOrcFile(input.map(Tuple1(_))) { path => + Seq(false, true).foreach { java8Api => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) { + readFile(path) { implicit df => + val timestamps = input.map(Literal(_)) + checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(timestamps(2)) < $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) >= $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + } + } + } } } diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 445a52cece1c..4b642080d25a 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc -import java.time.LocalDate +import java.time.{Instant, LocalDate} import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateToDays, toJavaDate} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -167,6 +167,8 @@ private[sql] object OrcFilters extends OrcFiltersBase { new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal])) case _: DateType if value.isInstanceOf[LocalDate] => toJavaDate(localDateToDays(value.asInstanceOf[LocalDate])) + case _: TimestampType if value.isInstanceOf[Instant] => + toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant])) case _ => value } diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 815af05beb00..2263179515a5 100644 --- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -246,29 +246,41 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } test("filter pushdown - timestamp") { - val timeString = "2015-08-20 14:57:00" - val timestamps = (1 to 4).map { i => - val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 - new Timestamp(milliseconds) - } - withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) - - checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val input = Seq( + "1000-01-01 01:02:03", + "1582-10-01 00:11:22", + "1900-01-01 23:59:59", + "2020-05-25 10:11:12").map(Timestamp.valueOf) - checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + withOrcFile(input.map(Tuple1(_))) { path => + Seq(false, true).foreach { java8Api => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) { + readFile(path) { implicit df => + val timestamps = input.map(Literal(_)) + checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate( - Literal(timestamps(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(timestamps(2)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(0)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(timestamps(2)) < $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) >= $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + } + } + } } }