Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
import java.time.Instant
import java.time.{Instant, LocalDate}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable

Expand Down Expand Up @@ -62,8 +62,9 @@ object CatalystTypeConverters {
case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType)
case structType: StructType => StructConverter(structType)
case StringType => StringConverter
case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
case DateType => DateConverter
case TimestampType if SQLConf.get.timestampExternalType == "Instant" => InstantConverter
case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter
case TimestampType => TimestampConverter
case dt: DecimalType => new DecimalConverter(dt)
case BooleanType => BooleanConverter
Expand Down Expand Up @@ -308,6 +309,18 @@ object CatalystTypeConverters {
DateTimeUtils.toJavaDate(row.getInt(column))
}

/**
 * Converter between `java.time.LocalDate` and the `Int` day value handled by
 * `DateTimeUtils.localDateToDays` / `DateTimeUtils.daysToLocalDate`.
 */
private object LocalDateConverter extends CatalystTypeConverter[LocalDate, LocalDate, Any] {
  override def toCatalystImpl(scalaValue: LocalDate): Int =
    DateTimeUtils.localDateToDays(scalaValue)

  // A null Catalyst value is passed through as a null LocalDate.
  override def toScala(catalystValue: Any): LocalDate = catalystValue match {
    case null => null
    case days => DateTimeUtils.daysToLocalDate(days.asInstanceOf[Int])
  }

  override def toScalaImpl(row: InternalRow, column: Int): LocalDate = {
    val days = row.getInt(column)
    DateTimeUtils.daysToLocalDate(days)
  }
}

private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] {
override def toCatalystImpl(scalaValue: Timestamp): Long =
DateTimeUtils.fromJavaTimestamp(scalaValue)
Expand Down Expand Up @@ -433,6 +446,7 @@ object CatalystTypeConverters {
def convertToCatalyst(a: Any): Any = a match {
case s: String => StringConverter.toCatalyst(s)
case d: Date => DateConverter.toCatalyst(d)
case ld: LocalDate => LocalDateConverter.toCatalyst(ld)
case t: Timestamp => TimestampConverter.toCatalyst(t)
case i: Instant => InstantConverter.toCatalyst(i)
case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}

/**
 * Builds the deserializer expression for `java.time.LocalDate`: a static call to
 * `DateTimeUtils.daysToLocalDate` applied to `path`, flagged as not returning null.
 */
def createDeserializerForLocalDate(path: Expression): Expression = {
  val localDateType = ObjectType(classOf[java.time.LocalDate])
  val arguments = path :: Nil
  StaticInvoke(
    DateTimeUtils.getClass,
    localDateType,
    "daysToLocalDate",
    arguments,
    returnNullable = false)
}

def createDeserializerForInstant(path: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ object JavaTypeInference {
c == classOf[java.lang.Boolean] =>
createDeserializerForTypesSupportValueOf(path, c)

case c if c == classOf[java.time.LocalDate] =>
createDeserializerForLocalDate(path)

case c if c == classOf[java.sql.Date] =>
createDeserializerForSqlDate(path)

Expand Down Expand Up @@ -393,6 +396,14 @@ object JavaTypeInference {
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.time.LocalDate] =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"localDateToDays",
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ object ScalaReflection extends ScalaReflection {
createDeserializerForTypesSupportValueOf(path,
classOf[java.lang.Boolean])

case t if t <:< localTypeOf[java.time.LocalDate] =>
createDeserializerForLocalDate(path)

case t if t <:< localTypeOf[java.sql.Date] =>
createDeserializerForSqlDate(path)

Expand Down Expand Up @@ -493,6 +496,14 @@ object ScalaReflection extends ScalaReflection {
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.time.LocalDate] =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"localDateToDays",
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
Expand Down Expand Up @@ -704,6 +715,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.time.Instant] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.time.LocalDate] => Schema(DateType, nullable = true)
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ import org.apache.spark.unsafe.types.UTF8String
* StringType -> String
* DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal
*
* DateType -> java.sql.Date
* TimestampType -> java.sql.Timestamp when spark.sql.catalyst.timestampType is set to Timestamp
* TimestampType -> java.time.Instant when spark.sql.catalyst.timestampType is set to Instant
* DateType -> java.sql.Date if spark.sql.datetime.java8API.enabled is false
* DateType -> java.time.LocalDate if spark.sql.datetime.java8API.enabled is true
*
* TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false
* TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
*
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq or Array
Expand Down Expand Up @@ -91,7 +93,7 @@ object RowEncoder {
dataType = ObjectType(udtClass), false)
Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)

case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
StaticInvoke(
DateTimeUtils.getClass,
TimestampType,
Expand All @@ -107,6 +109,14 @@ object RowEncoder {
inputObject :: Nil,
returnNullable = false)

case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"localDateToDays",
inputObject :: Nil,
returnNullable = false)

case DateType =>
StaticInvoke(
DateTimeUtils.getClass,
Expand Down Expand Up @@ -234,9 +244,11 @@ object RowEncoder {

def externalDataTypeFor(dt: DataType): DataType = dt match {
case _ if ScalaReflection.isNativeType(dt) => dt
case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
ObjectType(classOf[java.time.Instant])
case TimestampType => ObjectType(classOf[java.sql.Timestamp])
case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
ObjectType(classOf[java.time.LocalDate])
case DateType => ObjectType(classOf[java.sql.Date])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
case StringType => ObjectType(classOf[java.lang.String])
Expand Down Expand Up @@ -279,7 +291,7 @@ object RowEncoder {
dataType = ObjectType(udtClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)

case TimestampType if SQLConf.get.timestampExternalType == "Instant" =>
case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.time.Instant]),
Expand All @@ -295,6 +307,14 @@ object RowEncoder {
input :: Nil,
returnNullable = false)

case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.time.LocalDate]),
"daysToLocalDate",
input :: Nil,
returnNullable = false)

case DateType =>
StaticInvoke(
DateTimeUtils.getClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ object DateTimeUtils {
days.toInt
}

/**
 * Converts a `java.time.LocalDate` to its epoch day (the value of `LocalDate.toEpochDay`)
 * as an `Int`.
 *
 * The conversion is checked: an epoch day that does not fit into an `Int` raises an
 * error instead of being silently truncated to an unrelated date.
 *
 * @param localDate the date to convert
 * @return the epoch day of `localDate`
 * @throws ArithmeticException if the epoch day is outside the `Int` range
 */
def localDateToDays(localDate: LocalDate): Int = Math.toIntExact(localDate.toEpochDay)

/** Converts an epoch day (as accepted by `LocalDate.ofEpochDay`) into a `java.time.LocalDate`. */
def daysToLocalDate(days: Int): LocalDate = {
  val epochDay: Long = days.toLong
  LocalDate.ofEpochDay(epochDay)
}

/**
* Trim and parse a given UTF8 date string to a corresponding [[Int]] value.
* The return type is [[Option]] in order to distinguish between 0 and null. The following
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1683,11 +1683,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val TIMESTAMP_EXTERNAL_TYPE = buildConf("spark.sql.catalyst.timestampType")
.doc("Java class to/from which an instance of TimestampType is converted.")
.stringConf
.checkValues(Set("Timestamp", "Instant"))
.createWithDefault("Timestamp")
// Boolean switch (default: false) that selects java.time.Instant / java.time.LocalDate
// instead of java.sql.Timestamp / java.sql.Date as the external types of TimestampType
// and DateType.
// NOTE(review): the identifier is spelled "EANBLED" (presumably meant "ENABLED"). A rename
// has to update every reference in one change, so the name is left as-is here.
val DATETIME_JAVA8API_EANBLED = buildConf("spark.sql.datetime.java8API.enabled")
.doc("If the configuration property is set to true, java.time.Instant and " +
"java.time.LocalDate classes of Java 8 API are used as external types for " +
"Catalyst's TimestampType and DateType. If it is set to false, java.sql.Timestamp " +
"and java.sql.Date are used for the same purpose.")
.booleanConf
.createWithDefault(false)
}

/**
Expand Down Expand Up @@ -1875,7 +1877,7 @@ class SQLConf extends Serializable with Logging {

def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)

def timestampExternalType: String = getConf(TIMESTAMP_EXTERNAL_TYPE)
/** Current value of the `DATETIME_JAVA8API_EANBLED` setting, read via `getConf`. */
def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_EANBLED)

/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst

import java.time.Instant
import java.time.{Instant, LocalDate}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
Expand Down Expand Up @@ -169,7 +169,7 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
}

test("converting TimestampType to java.time.Instant") {
withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
withSQLConf(SQLConf.DATETIME_JAVA8API_EANBLED.key -> "true") {
Seq(
-9463427405253013L,
-244000001L,
Expand All @@ -181,4 +181,39 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
}
}
}

// Checks that convertToCatalyst maps a LocalDate to the same value as
// DateTimeUtils.localDateToDays for a spread of ISO-8601 dates.
test("converting java.time.LocalDate to DateType") {
  val isoDates = Seq(
    "0101-02-16",
    "1582-10-02",
    "1582-12-31",
    "1970-01-01",
    "1972-12-31",
    "2019-02-16",
    "2119-03-16")
  for (isoDate <- isoDates) {
    val localDate = LocalDate.parse(isoDate)
    val converted = CatalystTypeConverters.convertToCatalyst(localDate)
    val expectedDays = DateTimeUtils.localDateToDays(localDate)
    assert(converted === expectedDays)
  }
}

// With the Java 8 date-time API enabled, the Scala converter created for DateType must
// produce the same LocalDate as DateTimeUtils.daysToLocalDate.
test("converting DateType to java.time.LocalDate") {
  withSQLConf(SQLConf.DATETIME_JAVA8API_EANBLED.key -> "true") {
    val epochDays = Seq(
      -701265,
      -371419,
      -199722,
      -1,
      0,
      967,
      2094,
      17877,
      24837,
      1110657)
    for (epochDay <- epochDays) {
      val expected = DateTimeUtils.daysToLocalDate(epochDay)
      val toScala = CatalystTypeConverters.createToScalaConverter(DateType)
      assert(toScala(epochDay) === expected)
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}

test("encoding/decoding TimestampType to/from java.time.Instant") {
withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
withSQLConf(SQLConf.DATETIME_JAVA8API_EANBLED.key -> "true") {
val schema = new StructType().add("t", TimestampType)
val encoder = RowEncoder(schema).resolveAndBind()
val instant = java.time.Instant.parse("2019-02-26T16:56:00Z")
Expand All @@ -294,6 +294,18 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}
}

// Round-trips a java.time.LocalDate through RowEncoder with the Java 8 date-time API on.
test("encoding/decoding DateType to/from java.time.LocalDate") {
  withSQLConf(SQLConf.DATETIME_JAVA8API_EANBLED.key -> "true") {
    val schema = new StructType().add("d", DateType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val localDate = java.time.LocalDate.parse("2019-02-27")
    val row = encoder.toRow(Row(localDate))
    // The encoded value comes from localDateToDays, which yields an Int, so read the
    // column with getInt rather than getLong.
    assert(row.getInt(0) === DateTimeUtils.localDateToDays(localDate))
    val readback = encoder.fromRow(row)
    // `===` reports both operands on failure, unlike a bare `.equals` boolean.
    assert(readback.get(0) === localDate)
  }
}

for {
elementType <- Seq(IntegerType, StringType)
containsNull <- Seq(true, false)
Expand Down
18 changes: 18 additions & 0 deletions sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
package test.org.apache.spark.sql;

import java.io.Serializable;
import java.time.LocalDate;
import java.util.List;

import org.apache.spark.sql.internal.SQLConf;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
Expand Down Expand Up @@ -121,4 +123,20 @@ public void udf6Test() {
Row result = spark.sql("SELECT returnOne()").head();
Assert.assertEquals(1, result.getInt(0));
}

// Registers a UDF over java.time.LocalDate with the Java 8 date-time API switched on,
// then puts the previous setting back regardless of the test outcome.
@SuppressWarnings("unchecked")
@Test
public void udf7Test() {
  final String java8ApiKey = SQLConf.DATETIME_JAVA8API_EANBLED().key();
  final String previousValue = spark.conf().get(java8ApiKey);
  try {
    spark.conf().set(java8ApiKey, "true");
    spark.udf().register(
      "plusDay",
      (java.time.LocalDate ld) -> ld.plusDays(1), DataTypes.DateType);
    Row result = spark.sql("SELECT plusDay(DATE '2019-02-26')").head();
    Assert.assertEquals(LocalDate.parse("2019-02-27"), result.get(0));
  } finally {
    spark.conf().set(java8ApiKey, previousValue);
  }
}
}
12 changes: 11 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -496,12 +496,22 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}

test("Using java.time.Instant in UDF") {
withSQLConf(SQLConf.TIMESTAMP_EXTERNAL_TYPE.key -> "Instant") {
withSQLConf(SQLConf.DATETIME_JAVA8API_EANBLED.key -> "true") {
val expected = java.time.Instant.parse("2019-02-27T00:00:00Z")
val plusSec = udf((i: java.time.Instant) => i.plusSeconds(1))
val df = spark.sql("SELECT TIMESTAMP '2019-02-26 23:59:59Z' as t")
.select(plusSec('t))
assert(df.collect().toSeq === Seq(Row(expected)))
}
}

// A Scala UDF taking and returning java.time.LocalDate, applied to a DATE literal with
// the Java 8 date-time API enabled.
test("Using java.time.LocalDate in UDF") {
  withSQLConf(SQLConf.DATETIME_JAVA8API_EANBLED.key -> "true") {
    val expected = java.time.LocalDate.parse("2019-02-27")
    val nextDay = udf((date: java.time.LocalDate) => date.plusDays(1))
    val shifted = spark.sql("SELECT DATE '2019-02-26' as d").select(nextDay('d))
    val actualRows = shifted.collect().toSeq
    assert(actualRows === Seq(Row(expected)))
  }
}
}