diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index a20625b5d5f5..925d12c16bae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
-import java.time.Instant
+import java.time.{Instant, LocalDate}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
 
@@ -62,8 +62,9 @@ object CatalystTypeConverters {
       case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType)
       case structType: StructType => StructConverter(structType)
       case StringType => StringConverter
+      case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
       case DateType => DateConverter
-      case TimestampType if SQLConf.get.timestampExternalType == "Instant" => InstantConverter
+      case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter
       case TimestampType => TimestampConverter
       case dt: DecimalType => new DecimalConverter(dt)
       case BooleanType => BooleanConverter
@@ -308,6 +309,18 @@ object CatalystTypeConverters {
       DateTimeUtils.toJavaDate(row.getInt(column))
   }
 
+  private object LocalDateConverter extends CatalystTypeConverter[LocalDate, LocalDate, Any] {
+    override def toCatalystImpl(scalaValue: LocalDate): Int = {
+      DateTimeUtils.localDateToDays(scalaValue)
+    }
+    override def toScala(catalystValue: Any): LocalDate = {
+      if (catalystValue == null) null
+      else DateTimeUtils.daysToLocalDate(catalystValue.asInstanceOf[Int])
+    }
+    override def toScalaImpl(row: InternalRow, column: Int): LocalDate =
+      DateTimeUtils.daysToLocalDate(row.getInt(column))
+  }
+
   private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] {
     override def toCatalystImpl(scalaValue: Timestamp): Long =
       DateTimeUtils.fromJavaTimestamp(scalaValue)
@@ -433,6 +446,7 @@ object CatalystTypeConverters {
   def convertToCatalyst(a: Any): Any = a match {
     case s: String => StringConverter.toCatalyst(s)
     case d: Date => DateConverter.toCatalyst(d)
+    case ld: LocalDate => LocalDateConverter.toCatalyst(ld)
     case t: Timestamp => TimestampConverter.toCatalyst(t)
     case i: Instant => InstantConverter.toCatalyst(i)
     case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 3a2f38622d00..d75d3ca918c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -101,6 +101,15 @@ object DeserializerBuildHelper {
       returnNullable = false)
   }
 
+  def createDeserializerForLocalDate(path: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      ObjectType(classOf[java.time.LocalDate]),
+      "daysToLocalDate",
+      path :: Nil,
+      returnNullable = false)
+  }
+
   def createDeserializerForInstant(path: Expression): Expression = {
     StaticInvoke(
       DateTimeUtils.getClass,
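For intuition, both the converter and the StaticInvoke-based deserializer above reduce to plain epoch-day arithmetic; a minimal standalone sketch (the readLocalDate name is illustrative, not part of the patch):

    import java.time.LocalDate

    // DateType is stored internally as days since 1970-01-01 (an Int),
    // so deserialization is a single LocalDate.ofEpochDay call.
    def readLocalDate(daysSinceEpoch: Int): LocalDate =
      LocalDate.ofEpochDay(daysSinceEpoch)

    readLocalDate(0)      // 1970-01-01
    readLocalDate(17953)  // 2019-02-26
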
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 1822f9b036f7..87b2ae8cdf7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -221,6 +221,9 @@ object JavaTypeInference {
           c == classOf[java.lang.Boolean] =>
         createDeserializerForTypesSupportValueOf(path, c)
 
+      case c if c == classOf[java.time.LocalDate] =>
+        createDeserializerForLocalDate(path)
+
       case c if c == classOf[java.sql.Date] =>
         createDeserializerForSqlDate(path)
 
@@ -393,6 +396,14 @@ object JavaTypeInference {
           inputObject :: Nil,
           returnNullable = false)
 
+      case c if c == classOf[java.time.LocalDate] =>
+        StaticInvoke(
+          DateTimeUtils.getClass,
+          DateType,
+          "localDateToDays",
+          inputObject :: Nil,
+          returnNullable = false)
+
       case c if c == classOf[java.sql.Date] =>
         StaticInvoke(
           DateTimeUtils.getClass,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 26cc7b4d7ad8..bbddd3312a58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -194,6 +194,9 @@ object ScalaReflection extends ScalaReflection {
         createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Boolean])
 
+      case t if t <:< localTypeOf[java.time.LocalDate] =>
+        createDeserializerForLocalDate(path)
+
       case t if t <:< localTypeOf[java.sql.Date] =>
         createDeserializerForSqlDate(path)
 
@@ -493,6 +496,14 @@ object ScalaReflection extends ScalaReflection {
           inputObject :: Nil,
           returnNullable = false)
 
+      case t if t <:< localTypeOf[java.time.LocalDate] =>
+        StaticInvoke(
+          DateTimeUtils.getClass,
+          DateType,
+          "localDateToDays",
+          inputObject :: Nil,
+          returnNullable = false)
+
       case t if t <:< localTypeOf[java.sql.Date] =>
         StaticInvoke(
           DateTimeUtils.getClass,
@@ -704,6 +715,7 @@ object ScalaReflection extends ScalaReflection {
       case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
       case t if t <:< localTypeOf[java.time.Instant] => Schema(TimestampType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
+      case t if t <:< localTypeOf[java.time.LocalDate] => Schema(DateType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
       case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
       case t if t <:< localTypeOf[java.math.BigDecimal] =>
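With the new schemaFor case above, Scala type inference maps java.time.LocalDate straight to DateType. A quick sketch against that internal helper (assumes a Catalyst test classpath; not part of the patch):

    import java.time.LocalDate
    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.types.DateType

    // schemaFor is the internal helper the hunk above extends.
    val schema = ScalaReflection.schemaFor[LocalDate]
    assert(schema.dataType == DateType && schema.nullable)
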
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index eba68810790c..68a603b95ad5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -46,9 +46,11 @@ import org.apache.spark.unsafe.types.UTF8String
  *   StringType -> String
  *   DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal
  *
- *   DateType -> java.sql.Date
- *   TimestampType -> java.sql.Timestamp when spark.sql.catalyst.timestampType is set to Timestamp
- *   TimestampType -> java.time.Instant when spark.sql.catalyst.timestampType is set to Instant
+ *   DateType -> java.sql.Date if spark.sql.datetime.java8API.enabled is false
+ *   DateType -> java.time.LocalDate if spark.sql.datetime.java8API.enabled is true
+ *
+ *   TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false
+ *   TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
  *
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
@@ -91,7 +93,7 @@ object RowEncoder {
         dataType = ObjectType(udtClass), false)
       Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
 
-    case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
+    case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
       StaticInvoke(
         DateTimeUtils.getClass,
         TimestampType,
@@ -107,6 +109,14 @@ object RowEncoder {
         inputObject :: Nil,
         returnNullable = false)
 
+    case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
+      StaticInvoke(
+        DateTimeUtils.getClass,
+        DateType,
+        "localDateToDays",
+        inputObject :: Nil,
+        returnNullable = false)
+
     case DateType =>
       StaticInvoke(
         DateTimeUtils.getClass,
@@ -234,9 +244,11 @@ object RowEncoder {
 
   def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
-    case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
+    case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
       ObjectType(classOf[java.time.Instant])
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
+    case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
+      ObjectType(classOf[java.time.LocalDate])
     case DateType => ObjectType(classOf[java.sql.Date])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
     case StringType => ObjectType(classOf[java.lang.String])
@@ -279,7 +291,7 @@ object RowEncoder {
         dataType = ObjectType(udtClass))
       Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
 
-    case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
+    case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
       StaticInvoke(
         DateTimeUtils.getClass,
         ObjectType(classOf[java.time.Instant]),
@@ -295,6 +307,14 @@ object RowEncoder {
         input :: Nil,
         returnNullable = false)
 
+    case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
+      StaticInvoke(
+        DateTimeUtils.getClass,
+        ObjectType(classOf[java.time.LocalDate]),
+        "daysToLocalDate",
+        input :: Nil,
+        returnNullable = false)
+
     case DateType =>
       StaticInvoke(
         DateTimeUtils.getClass,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 742cee6255bf..5064220b9562 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -367,6 +367,10 @@ object DateTimeUtils {
     days.toInt
   }
 
+  // Math.toIntExact fails fast on overflow instead of silently truncating.
+  def localDateToDays(localDate: LocalDate): Int = Math.toIntExact(localDate.toEpochDay)
+
+  def daysToLocalDate(days: Int): LocalDate = LocalDate.ofEpochDay(days)
+
   /**
    * Trim and parse a given UTF8 date string to a corresponding [[Int]] value.
    * The return type is [[Option]] in order to distinguish between 0 and null.
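The two DateTimeUtils helpers are thin wrappers over java.time's own epoch-day support, so they can be sanity-checked without Spark at all; a round-trip sketch:

    import java.time.LocalDate

    val date = LocalDate.parse("2019-02-26")
    val days = Math.toIntExact(date.toEpochDay)  // 17953, what localDateToDays returns
    assert(LocalDate.ofEpochDay(days) == date)   // what daysToLocalDate undoes
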
The following hunk replaces the string-valued spark.sql.catalyst.timestampType option with a single boolean flag, spark.sql.datetime.java8API.enabled, that covers both TimestampType and DateType:

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2cf471e47d74..380f5cfd8864 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1683,11 +1683,13 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
-  val TIMESTAMP_EXTERNAL_TYPE = buildConf("spark.sql.catalyst.timestampType")
-    .doc("Java class to/from which an instance of TimestampType is converted.")
-    .stringConf
-    .checkValues(Set("Timestamp", "Instant"))
-    .createWithDefault("Timestamp")
+  val DATETIME_JAVA8API_ENABLED = buildConf("spark.sql.datetime.java8API.enabled")
+    .doc("If the configuration property is set to true, java.time.Instant and " +
+      "java.time.LocalDate classes of Java 8 API are used as external types for " +
+      "Catalyst's TimestampType and DateType. If it is set to false, java.sql.Timestamp " +
+      "and java.sql.Date are used for the same purpose.")
+    .booleanConf
+    .createWithDefault(false)
 }
 
 /**
@@ -1875,7 +1877,7 @@ class SQLConf extends Serializable with Logging {
 
   def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)
 
-  def timestampExternalType: String = getConf(TIMESTAMP_EXTERNAL_TYPE)
+  def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)
 
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 828e8d89977b..6999526c801c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
-import java.time.Instant
+import java.time.{Instant, LocalDate}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
@@ -169,7 +169,7 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("converting TimestampType to java.time.Instant") {
-    withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
       Seq(
         -9463427405253013L,
         -244000001L,
@@ -181,4 +181,39 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
       }
     }
   }
+
+  test("converting java.time.LocalDate to DateType") {
+    Seq(
+      "0101-02-16",
+      "1582-10-02",
+      "1582-12-31",
+      "1970-01-01",
+      "1972-12-31",
+      "2019-02-16",
+      "2119-03-16").foreach { date =>
+      val input = LocalDate.parse(date)
+      val result = CatalystTypeConverters.convertToCatalyst(input)
+      val expected = DateTimeUtils.localDateToDays(input)
+      assert(result === expected)
+    }
+  }
+
+  test("converting DateType to java.time.LocalDate") {
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      Seq(
+        -701265,
+        -371419,
+        -199722,
+        -1,
+        0,
+        967,
+        2094,
+        17877,
+        24837,
+        1110657).foreach { days =>
+        val localDate = DateTimeUtils.daysToLocalDate(days)
+        assert(CatalystTypeConverters.createToScalaConverter(DateType)(days) === localDate)
+      }
+    }
+  }
 }
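End to end, the flag changes only the external Java/Scala type of date values, not their storage; a minimal spark-shell sketch (assumes an active session named spark):

    import java.time.LocalDate

    spark.conf.set("spark.sql.datetime.java8API.enabled", true)
    val row = spark.sql("SELECT DATE '2019-02-26' AS d").head()
    // With the flag on, collect()/head() hand back java.time.LocalDate
    // instead of java.sql.Date.
    assert(row.get(0) == LocalDate.parse("2019-02-26"))
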
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 691056db6e7c..79c8abbcdc91 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -283,7 +283,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
   }
 
   test("encoding/decoding TimestampType to/from java.time.Instant") {
-    withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
       val schema = new StructType().add("t", TimestampType)
       val encoder = RowEncoder(schema).resolveAndBind()
       val instant = java.time.Instant.parse("2019-02-26T16:56:00Z")
@@ -294,6 +294,18 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     }
   }
 
+  test("encoding/decoding DateType to/from java.time.LocalDate") {
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      val schema = new StructType().add("d", DateType)
+      val encoder = RowEncoder(schema).resolveAndBind()
+      val localDate = java.time.LocalDate.parse("2019-02-27")
+      val row = encoder.toRow(Row(localDate))
+      assert(row.getInt(0) === DateTimeUtils.localDateToDays(localDate))
+      val readback = encoder.fromRow(row)
+      assert(readback.get(0).equals(localDate))
+    }
+  }
+
   for {
     elementType <- Seq(IntegerType, StringType)
     containsNull <- Seq(true, false)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index 5bf188882618..a8ba8753aab5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -18,8 +18,10 @@
 package test.org.apache.spark.sql;
 
 import java.io.Serializable;
+import java.time.LocalDate;
 import java.util.List;
 
+import org.apache.spark.sql.internal.SQLConf;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -121,4 +123,20 @@ public void udf6Test() {
     Row result = spark.sql("SELECT returnOne()").head();
     Assert.assertEquals(1, result.getInt(0));
   }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void udf7Test() {
+    String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key());
+    try {
+      spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), "true");
+      spark.udf().register(
+          "plusDay",
+          (java.time.LocalDate ld) -> ld.plusDays(1), DataTypes.DateType);
+      Row result = spark.sql("SELECT plusDay(DATE '2019-02-26')").head();
+      Assert.assertEquals(LocalDate.parse("2019-02-27"), result.get(0));
+    } finally {
+      spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), originConf);
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 0f5f0efb0fe7..794c597d3f65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -496,7 +496,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   }
 
   test("Using java.time.Instant in UDF") {
-    withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
      val expected = java.time.Instant.parse("2019-02-27T00:00:00Z")
      val plusSec = udf((i: java.time.Instant) => i.plusSeconds(1))
      val df = spark.sql("SELECT TIMESTAMP '2019-02-26 23:59:59Z' as t")
@@ -504,4 +504,14 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       assert(df.collect().toSeq === Seq(Row(expected)))
     }
   }
+
+  test("Using java.time.LocalDate in UDF") {
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      val expected = java.time.LocalDate.parse("2019-02-27")
+      val plusDay = udf((i: java.time.LocalDate) => i.plusDays(1))
+      val df = spark.sql("SELECT DATE '2019-02-26' as d")
+        .select(plusDay('d))
+      assert(df.collect().toSeq === Seq(Row(expected)))
+    }
+  }
 }
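Beyond Row-based access and UDFs, the ScalaReflection changes above also make LocalDate usable as a case-class field in typed Datasets, with no flag needed since typed encoders dispatch on the field type itself; a sketch (Event is a made-up class; assumes spark and its implicits):

    import java.time.LocalDate

    case class Event(name: String, day: LocalDate)

    import spark.implicits._
    val ds = Seq(Event("release", LocalDate.parse("2019-02-26"))).toDS()
    assert(ds.head().day == LocalDate.parse("2019-02-26"))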