diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 93df73ab1eaf..a20625b5d5f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -21,6 +21,7 @@ import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
+import java.time.Instant
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
 
@@ -29,6 +30,7 @@ import scala.language.existentials
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -61,6 +63,7 @@ object CatalystTypeConverters {
       case structType: StructType => StructConverter(structType)
       case StringType => StringConverter
       case DateType => DateConverter
+      case TimestampType if SQLConf.get.timestampExternalType == "Instant" => InstantConverter
       case TimestampType => TimestampConverter
       case dt: DecimalType => new DecimalConverter(dt)
       case BooleanType => BooleanConverter
@@ -315,6 +318,16 @@ object CatalystTypeConverters {
       DateTimeUtils.toJavaTimestamp(row.getLong(column))
   }
 
+  private object InstantConverter extends CatalystTypeConverter[Instant, Instant, Any] {
+    override def toCatalystImpl(scalaValue: Instant): Long =
+      DateTimeUtils.instantToMicros(scalaValue)
+    override def toScala(catalystValue: Any): Instant =
+      if (catalystValue == null) null
+      else DateTimeUtils.microsToInstant(catalystValue.asInstanceOf[Long])
+    override def toScalaImpl(row: InternalRow, column: Int): Instant =
+      DateTimeUtils.microsToInstant(row.getLong(column))
+  }
+
   private class DecimalConverter(dataType: DecimalType)
     extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
     override def toCatalystImpl(scalaValue: Any): Decimal = {
@@ -421,6 +434,7 @@
     case s: String => StringConverter.toCatalyst(s)
     case d: Date => DateConverter.toCatalyst(d)
     case t: Timestamp => TimestampConverter.toCatalyst(t)
+    case i: Instant => InstantConverter.toCatalyst(i)
     case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
     case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
     case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index e71955ab4e75..3a2f38622d00 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -101,6 +101,15 @@ object DeserializerBuildHelper {
       returnNullable = false)
   }
 
+  def createDeserializerForInstant(path: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      ObjectType(classOf[java.time.Instant]),
+      "microsToInstant",
+      path :: Nil,
+      returnNullable = false)
+  }
+
   def createDeserializerForSqlTimestamp(path: Expression): Expression = {
     StaticInvoke(
       DateTimeUtils.getClass,
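For orientation: the new converter pair just maps java.time.Instant to Catalyst's internal representation of TimestampType, a Long counting microseconds since the epoch, and back. A minimal sketch of the forward direction, assuming this patch is applied (hypothetical snippet, not part of the diff):

    import java.time.Instant
    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    val in = Instant.parse("2019-02-16T18:12:30Z")
    // Scala-side Instant -> Catalyst Long (microseconds since 1970-01-01T00:00:00Z).
    val micros = CatalystTypeConverters.convertToCatalyst(in)
    assert(micros == DateTimeUtils.instantToMicros(in))
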
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index dafa87839ec6..1822f9b036f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -224,6 +224,9 @@ object JavaTypeInference {
       case c if c == classOf[java.sql.Date] =>
         createDeserializerForSqlDate(path)
 
+      case c if c == classOf[java.time.Instant] =>
+        createDeserializerForInstant(path)
+
       case c if c == classOf[java.sql.Timestamp] =>
         createDeserializerForSqlTimestamp(path)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 741cba80640b..26cc7b4d7ad8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -197,6 +197,9 @@
       case t if t <:< localTypeOf[java.sql.Date] =>
         createDeserializerForSqlDate(path)
 
+      case t if t <:< localTypeOf[java.time.Instant] =>
+        createDeserializerForInstant(path)
+
       case t if t <:< localTypeOf[java.sql.Timestamp] =>
         createDeserializerForSqlTimestamp(path)
 
@@ -474,6 +477,14 @@ object ScalaReflection extends ScalaReflection {
           inputObject :: Nil,
           returnNullable = false)
 
+      case t if t <:< localTypeOf[java.time.Instant] =>
+        StaticInvoke(
+          DateTimeUtils.getClass,
+          TimestampType,
+          "instantToMicros",
+          inputObject :: Nil,
+          returnNullable = false)
+
       case t if t <:< localTypeOf[java.sql.Timestamp] =>
         StaticInvoke(
           DateTimeUtils.getClass,
@@ -691,6 +702,7 @@ object ScalaReflection extends ScalaReflection {
         val Schema(dataType, nullable) = schemaFor(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
       case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
+      case t if t <:< localTypeOf[java.time.Instant] => Schema(TimestampType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
       case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
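With the reflection changes above, an Instant field in a case class now infers TimestampType. A minimal sketch, assuming this patch (Event is a hypothetical example class, not part of the diff):

    import org.apache.spark.sql.catalyst.ScalaReflection

    case class Event(time: java.time.Instant)

    // The new schemaFor case maps the Instant field to TimestampType.
    val schema = ScalaReflection.schemaFor[Event]
    // schema.dataType == StructType(StructField("time", TimestampType, nullable = true) :: Nil)
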
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 8ca3d356f3bd..eba68810790c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -46,7 +47,8 @@ import org.apache.spark.unsafe.types.UTF8String
  *   DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal
  *
  *   DateType -> java.sql.Date
- *   TimestampType -> java.sql.Timestamp
+ *   TimestampType -> java.sql.Timestamp when spark.sql.catalyst.timestampType is set to Timestamp
+ *   TimestampType -> java.time.Instant when spark.sql.catalyst.timestampType is set to Instant
  *
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
@@ -89,6 +91,14 @@ object RowEncoder {
         dataType = ObjectType(udtClass), false)
       Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
 
+    case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
+      StaticInvoke(
+        DateTimeUtils.getClass,
+        TimestampType,
+        "instantToMicros",
+        inputObject :: Nil,
+        returnNullable = false)
+
     case TimestampType =>
       StaticInvoke(
         DateTimeUtils.getClass,
@@ -224,6 +234,8 @@ object RowEncoder {
 
   def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
+    case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
+      ObjectType(classOf[java.time.Instant])
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
     case DateType => ObjectType(classOf[java.sql.Date])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -267,6 +279,14 @@ object RowEncoder {
         dataType = ObjectType(udtClass))
       Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
 
+    case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
+      StaticInvoke(
+        DateTimeUtils.getClass,
+        ObjectType(classOf[java.time.Instant]),
+        "microsToInstant",
+        input :: Nil,
+        returnNullable = false)
+
     case TimestampType =>
       StaticInvoke(
         DateTimeUtils.getClass,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 627670afec9f..742cee6255bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -355,6 +355,12 @@ object DateTimeUtils {
     result
   }
 
+  def microsToInstant(us: Long): Instant = {
+    val secs = Math.floorDiv(us, MICROS_PER_SECOND)
+    val mos = Math.floorMod(us, MICROS_PER_SECOND)
+    Instant.ofEpochSecond(secs, mos * NANOS_PER_MICROS)
+  }
+
   def instantToDays(instant: Instant): Int = {
     val seconds = instant.getEpochSecond
     val days = Math.floorDiv(seconds, SECONDS_PER_DAY)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
index 4ec61e1ca4a5..c25481586e73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
@@ -61,10 +61,7 @@ class Iso8601TimestampFormatter(
   override def parse(s: String): Long = instantToMicros(toInstant(s))
 
   override def format(us: Long): String = {
-    val secs = Math.floorDiv(us, DateTimeUtils.MICROS_PER_SECOND)
-    val mos = Math.floorMod(us, DateTimeUtils.MICROS_PER_SECOND)
-    val instant = Instant.ofEpochSecond(secs, mos * DateTimeUtils.NANOS_PER_MICROS)
-
+    val instant = DateTimeUtils.microsToInstant(us)
     formatter.withZone(timeZone.toZoneId).format(instant)
   }
 }
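The floorDiv/floorMod pair in microsToInstant decomposes a possibly negative microsecond count into a floor second and a non-negative sub-second remainder, which is exactly the (epochSecond, nanoAdjustment) shape Instant.ofEpochSecond expects. A worked example for a pre-epoch value (hypothetical standalone check, not part of the diff):

    import java.time.Instant

    // One microsecond before the epoch: us = -1L.
    // Math.floorDiv(-1L, 1000000L) == -1L and Math.floorMod(-1L, 1000000L) == 999999L,
    // so the result is Instant.ofEpochSecond(-1, 999999000),
    // i.e. 1969-12-31T23:59:59.999999Z.
    val i = Instant.ofEpochSecond(Math.floorDiv(-1L, 1000000L),
      Math.floorMod(-1L, 1000000L) * 1000L)
    assert(i == Instant.parse("1969-12-31T23:59:59.999999Z"))
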
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e74c2af476ee..2cf471e47d74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1682,6 +1682,12 @@
       "a SparkConf entry.")
     .booleanConf
     .createWithDefault(true)
+
+  val TIMESTAMP_EXTERNAL_TYPE = buildConf("spark.sql.catalyst.timestampType")
+    .doc("Java class to/from which an instance of TimestampType is converted.")
+    .stringConf
+    .checkValues(Set("Timestamp", "Instant"))
+    .createWithDefault("Timestamp")
 }
 
 /**
@@ -1869,6 +1875,8 @@ class SQLConf extends Serializable with Logging {
 
   def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)
 
+  def timestampExternalType: String = getConf(TIMESTAMP_EXTERNAL_TYPE)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 89452ee05cff..828e8d89977b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -17,14 +17,18 @@
 
 package org.apache.spark.sql.catalyst
 
+import java.time.Instant
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-class CatalystTypeConvertersSuite extends SparkFunSuite {
+class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
 
   private val simpleTypes: Seq[DataType] = Seq(
     StringType,
@@ -147,4 +151,34 @@ class CatalystTypeConvertersSuite extends SparkFunSuite {
     val expected = UTF8String.fromString("X")
     assert(converter(chr) === expected)
   }
+
+  test("converting java.time.Instant to TimestampType") {
+    Seq(
+      "0101-02-16T10:11:32Z",
+      "1582-10-02T01:02:03.04Z",
+      "1582-12-31T23:59:59.999999Z",
+      "1970-01-01T00:00:01.123Z",
+      "1972-12-31T23:59:59.123456Z",
+      "2019-02-16T18:12:30Z",
+      "2119-03-16T19:13:31Z").foreach { timestamp =>
+      val input = Instant.parse(timestamp)
+      val result = CatalystTypeConverters.convertToCatalyst(input)
+      val expected = DateTimeUtils.instantToMicros(input)
+      assert(result === expected)
+    }
+  }
+
+  test("converting TimestampType to java.time.Instant") {
+    withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
+      Seq(
+        -9463427405253013L,
+        -244000001L,
+        0L,
+        99628200102030L,
+        1543749753123456L).foreach { us =>
+        val instant = DateTimeUtils.microsToInstant(us)
+        assert(CatalystTypeConverters.createToScalaConverter(TimestampType)(us) === instant)
+      }
+    }
+  }
 }
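A minimal usage sketch for the new flag (hypothetical session on top of this patch; the default stays "Timestamp", so existing behaviour is unchanged unless users opt in):

    // Opt in to java.time.Instant as the external type for TimestampType.
    spark.conf.set("spark.sql.catalyst.timestampType", "Instant")

    val row = spark.sql("SELECT TIMESTAMP '2019-02-26 23:59:59Z' AS t").head()
    // Collected values are now java.time.Instant instead of java.sql.Timestamp.
    val t = row.getAs[java.time.Instant]("t")
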
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index ab819bec72e8..691056db6e7c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -21,7 +21,8 @@ import scala.util.Random
 
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
@@ -281,6 +282,18 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     assert(encoder.serializer(0).dataType == pythonUDT.sqlType)
   }
 
+  test("encoding/decoding TimestampType to/from java.time.Instant") {
+    withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
+      val schema = new StructType().add("t", TimestampType)
+      val encoder = RowEncoder(schema).resolveAndBind()
+      val instant = java.time.Instant.parse("2019-02-26T16:56:00Z")
+      val row = encoder.toRow(Row(instant))
+      assert(row.getLong(0) === DateTimeUtils.instantToMicros(instant))
+      val readback = encoder.fromRow(row)
+      assert(readback.get(0) === instant)
+    }
+  }
+
   for {
     elementType <- Seq(IntegerType, StringType)
     containsNull <- Seq(true, false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index e515800e6898..0f5f0efb0fe7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.functions.{lit, udf}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
@@ -493,4 +494,14 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       checkAnswer(df, Seq(Row(1L), Row(2L)))
     }
   }
+
+  test("Using java.time.Instant in UDF") {
+    withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
+      val expected = java.time.Instant.parse("2019-02-27T00:00:00Z")
+      val plusSec = udf((i: java.time.Instant) => i.plusSeconds(1))
+      val df = spark.sql("SELECT TIMESTAMP '2019-02-26 23:59:59Z' as t")
+        .select(plusSec('t))
+      assert(df.collect().toSeq === Seq(Row(expected)))
+    }
+  }
 }
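The round trips in these tests rely on microsToInstant and instantToMicros being exact inverses at microsecond precision. A property-style spot check, assuming this patch (hypothetical snippet, not part of the diff):

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    Seq(-9463427405253013L, -244000001L, 0L, 99628200102030L).foreach { us =>
      assert(DateTimeUtils.instantToMicros(DateTimeUtils.microsToInstant(us)) == us)
    }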