@@ -21,6 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
import java.time.Instant
import java.util.{Map => JavaMap}
import javax.annotation.Nullable

@@ -29,6 +30,7 @@ import scala.language.existentials
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -61,6 +63,7 @@ object CatalystTypeConverters {
case structType: StructType => StructConverter(structType)
case StringType => StringConverter
case DateType => DateConverter
case TimestampType if SQLConf.get.timestampExternalType == "Instant" => InstantConverter
case TimestampType => TimestampConverter
case dt: DecimalType => new DecimalConverter(dt)
case BooleanType => BooleanConverter
@@ -315,6 +318,16 @@ object CatalystTypeConverters {
DateTimeUtils.toJavaTimestamp(row.getLong(column))
}

private object InstantConverter extends CatalystTypeConverter[Instant, Instant, Any] {
override def toCatalystImpl(scalaValue: Instant): Long =
DateTimeUtils.instantToMicros(scalaValue)
override def toScala(catalystValue: Any): Instant =
if (catalystValue == null) null
else DateTimeUtils.microsToInstant(catalystValue.asInstanceOf[Long])
override def toScalaImpl(row: InternalRow, column: Int): Instant =
DateTimeUtils.microsToInstant(row.getLong(column))
}

private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
override def toCatalystImpl(scalaValue: Any): Decimal = {
@@ -421,6 +434,7 @@ object CatalystTypeConverters {
case s: String => StringConverter.toCatalyst(s)
case d: Date => DateConverter.toCatalyst(d)
case t: Timestamp => TimestampConverter.toCatalyst(t)
case i: Instant => InstantConverter.toCatalyst(i)
case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)

@@ -101,6 +101,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}

def createDeserializerForInstant(path: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.time.Instant]),
"microsToInstant",
path :: Nil,
returnNullable = false)
}

def createDeserializerForSqlTimestamp(path: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,

@@ -224,6 +224,9 @@ object JavaTypeInference {
case c if c == classOf[java.sql.Date] =>
createDeserializerForSqlDate(path)

case c if c == classOf[java.time.Instant] =>
createDeserializerForInstant(path)

case c if c == classOf[java.sql.Timestamp] =>
createDeserializerForSqlTimestamp(path)


@@ -197,6 +197,9 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.sql.Date] =>
createDeserializerForSqlDate(path)

case t if t <:< localTypeOf[java.time.Instant] =>
createDeserializerForInstant(path)

case t if t <:< localTypeOf[java.sql.Timestamp] =>
createDeserializerForSqlTimestamp(path)

@@ -474,6 +477,14 @@ object ScalaReflection extends ScalaReflection {
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.time.Instant] =>
StaticInvoke(
DateTimeUtils.getClass,
TimestampType,
"instantToMicros",
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
@@ -691,6 +702,7 @@ object ScalaReflection extends ScalaReflection {
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.time.Instant] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
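
As a quick sanity check of the new `schemaFor` case (a sketch, not part of this diff; the `Event` case class is hypothetical), an `Instant` field in a product type should now map to `TimestampType`:

```scala
import java.time.Instant

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{StructType, TimestampType}

// Hypothetical case class, used only to illustrate schema inference.
case class Event(id: Long, time: Instant)

val schema = ScalaReflection.schemaFor[Event].dataType.asInstanceOf[StructType]
assert(schema("time").dataType == TimestampType) // handled by the new Instant case
assert(schema("time").nullable)                  // Instant fields are nullable, like Timestamp
```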

@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -46,7 +47,8 @@ import org.apache.spark.unsafe.types.UTF8String
* DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal
*
* DateType -> java.sql.Date
* TimestampType -> java.sql.Timestamp
* TimestampType -> java.sql.Timestamp when spark.sql.catalyst.timestampType is set to Timestamp
* TimestampType -> java.time.Instant when spark.sql.catalyst.timestampType is set to Instant
*
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq or Array
@@ -89,6 +91,14 @@ object RowEncoder {
dataType = ObjectType(udtClass), false)
Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)

case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
StaticInvoke(
DateTimeUtils.getClass,
TimestampType,
"instantToMicros",
inputObject :: Nil,
returnNullable = false)

case TimestampType =>
StaticInvoke(
DateTimeUtils.getClass,
@@ -224,6 +234,8 @@ object RowEncoder {

def externalDataTypeFor(dt: DataType): DataType = dt match {
case _ if ScalaReflection.isNativeType(dt) => dt
case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
ObjectType(classOf[java.time.Instant])
case TimestampType => ObjectType(classOf[java.sql.Timestamp])
case DateType => ObjectType(classOf[java.sql.Date])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -267,6 +279,14 @@ object RowEncoder {
dataType = ObjectType(udtClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)

case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.time.Instant]),
"microsToInstant",
input :: Nil,
returnNullable = false)

case TimestampType =>
StaticInvoke(
DateTimeUtils.getClass,

@@ -355,6 +355,12 @@ object DateTimeUtils {
result
}

def microsToInstant(us: Long): Instant = {
val secs = Math.floorDiv(us, MICROS_PER_SECOND)
val mos = Math.floorMod(us, MICROS_PER_SECOND)
Instant.ofEpochSecond(secs, mos * NANOS_PER_MICROS)
}
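
For pre-epoch values, `floorDiv`/`floorMod` split the microsecond count into whole seconds rounded toward negative infinity plus a non-negative remainder, which `Instant.ofEpochSecond` then combines. A standalone sketch of the boundary behavior (a hedged illustration; the constants are duplicated here rather than taken from `DateTimeUtils`):

```scala
import java.time.Instant

val MICROS_PER_SECOND = 1000000L
val NANOS_PER_MICROS = 1000L

def microsToInstant(us: Long): Instant = {
  val secs = Math.floorDiv(us, MICROS_PER_SECOND) // rounds toward negative infinity
  val mos = Math.floorMod(us, MICROS_PER_SECOND)  // always in [0, 999999]
  Instant.ofEpochSecond(secs, mos * NANOS_PER_MICROS)
}

// -1 microsecond is one microsecond before the epoch.
assert(microsToInstant(-1L) == Instant.parse("1969-12-31T23:59:59.999999Z"))
assert(microsToInstant(1L) == Instant.parse("1970-01-01T00:00:00.000001Z"))
```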

def instantToDays(instant: Instant): Int = {
val seconds = instant.getEpochSecond
val days = Math.floorDiv(seconds, SECONDS_PER_DAY)

@@ -61,10 +61,7 @@ class Iso8601TimestampFormatter(
override def parse(s: String): Long = instantToMicros(toInstant(s))

override def format(us: Long): String = {
val secs = Math.floorDiv(us, DateTimeUtils.MICROS_PER_SECOND)
val mos = Math.floorMod(us, DateTimeUtils.MICROS_PER_SECOND)
val instant = Instant.ofEpochSecond(secs, mos * DateTimeUtils.NANOS_PER_MICROS)

val instant = DateTimeUtils.microsToInstant(us)
formatter.withZone(timeZone.toZoneId).format(instant)
}
}

@@ -1682,6 +1682,12 @@ object SQLConf {
"a SparkConf entry.")
.booleanConf
.createWithDefault(true)

val TIMESTAMP_EXTERNAL_TYPE = buildConf("spark.sql.catalyst.timestampType")

Member:
We can support reading from both types at the same time, right?
I don't know if it's worth changing what it is written to; not worth a flag, IMHO.

Member (Author):
> We can support reading from both types at the same time, right?

On the Spark side, we can read both.

> I don't know if it's worth changing what it is written to; not worth a flag, IMHO.

Timestamps can be loaded from a data source, cast from other types, and so on. If a user wants to import (collect) non-legacy timestamps (that is, java.time.Instant), how can they do that without the flag?

Member:
Import is fine; we could potentially read both types into TimestampType. Can we just be opinionated about the right way to write it back out, and keep the current behavior? It may be 'legacy', but I'm not sure it's worth the behavior change. You may have more context on why that's important, though.

As with many things, I just don't know how realistically people will understand the issue, find the flag, set it, and maintain it across deployments.

Member (Author):
> It may be 'legacy', but I'm not sure it's worth the behavior change.

The SQL config spark.sql.catalyst.timestampType has the default value Timestamp, which preserves the current behavior. When a user wants to collect java.time.Instant values from Spark, they can change the config to select that external class.

.doc("Java class to/from which an instance of TimestampType is converted.")
.stringConf
.checkValues(Set("Timestamp", "Instant"))
.createWithDefault("Timestamp")
}
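
For context, a hypothetical end-to-end usage sketch (it assumes a `SparkSession` named `spark`; the config key and the collect behavior follow from the `RowEncoder` changes in this diff):

```scala
import java.time.Instant

import org.apache.spark.sql.internal.SQLConf

// Opt in to the proposed external type for TimestampType.
spark.conf.set(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key, "Instant")

// Collecting a timestamp column should now yield java.time.Instant
// values instead of java.sql.Timestamp.
val row = spark.sql("SELECT TIMESTAMP '2019-02-26 16:56:00' AS t").head()
assert(row.get(0).isInstanceOf[Instant])
```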

/**
@@ -1869,6 +1875,8 @@ class SQLConf extends Serializable with Logging {

def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)

def timestampExternalType: String = getConf(TIMESTAMP_EXTERNAL_TYPE)

/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.

@@ -17,14 +17,18 @@

package org.apache.spark.sql.catalyst

import java.time.Instant

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class CatalystTypeConvertersSuite extends SparkFunSuite {
class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {

private val simpleTypes: Seq[DataType] = Seq(
StringType,
@@ -147,4 +151,34 @@ class CatalystTypeConvertersSuite extends SparkFunSuite {
val expected = UTF8String.fromString("X")
assert(converter(chr) === expected)
}

test("converting java.time.Instant to TimestampType") {
Seq(
"0101-02-16T10:11:32Z",
"1582-10-02T01:02:03.04Z",
"1582-12-31T23:59:59.999999Z",
"1970-01-01T00:00:01.123Z",
"1972-12-31T23:59:59.123456Z",
"2019-02-16T18:12:30Z",
"2119-03-16T19:13:31Z").foreach { timestamp =>
val input = Instant.parse(timestamp)
val result = CatalystTypeConverters.convertToCatalyst(input)
val expected = DateTimeUtils.instantToMicros(input)
assert(result === expected)
}
}

test("converting TimestampType to java.time.Instant") {
withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
Seq(
-9463427405253013L,
-244000001L,
0L,
99628200102030L,
1543749753123456L).foreach { us =>
val instant = DateTimeUtils.microsToInstant(us)
assert(CatalystTypeConverters.createToScalaConverter(TimestampType)(us) === instant)
}
}
}
}

@@ -21,7 +21,8 @@ import scala.util.Random

import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
@@ -281,6 +282,18 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
assert(encoder.serializer(0).dataType == pythonUDT.sqlType)
}

test("encoding/decoding TimestampType to/from java.time.Instant") {
withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
val schema = new StructType().add("t", TimestampType)
val encoder = RowEncoder(schema).resolveAndBind()
val instant = java.time.Instant.parse("2019-02-26T16:56:00Z")
val row = encoder.toRow(Row(instant))
assert(row.getLong(0) === DateTimeUtils.instantToMicros(instant))
val readback = encoder.fromRow(row)
assert(readback.get(0) === instant)
}
}

for {
elementType <- Seq(IntegerType, StringType)
containsNull <- Seq(true, false)

11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.functions.{lit, udf}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
@@ -493,4 +494,14 @@ class UDFSuite extends QueryTest with SharedSQLContext {
checkAnswer(df, Seq(Row(1L), Row(2L)))
}
}

test("Using java.time.Instant in UDF") {
withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
val expected = java.time.Instant.parse("2019-02-27T00:00:00Z")
val plusSec = udf((i: java.time.Instant) => i.plusSeconds(1))
val df = spark.sql("SELECT TIMESTAMP '2019-02-26 23:59:59Z' as t")
.select(plusSec('t))
assert(df.collect().toSeq === Seq(Row(expected)))
}
}
}