diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index f54c6920ce82..055fbc49bdcd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -149,7 +149,7 @@ object Encoders {
    * - boxed types: Boolean, Integer, Double, etc.
    * - String
    * - java.math.BigDecimal, java.math.BigInteger
-   * - time related: java.sql.Date, java.sql.Timestamp
+   * - time related: java.sql.Date, java.sql.Timestamp, java.time.LocalDate, java.time.Instant
    * - collection types: only array and java.util.List currently, map support is in progress
    * - nested java bean.
    *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 4f5af9ac80b1..f13eddee77e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -269,6 +269,13 @@ trait Row extends Serializable {
    */
   def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i)
 
+  /**
+   * Returns the value at position i of date type as java.time.LocalDate.
+   *
+   * @throws ClassCastException when data type does not match.
+   */
+  def getLocalDate(i: Int): java.time.LocalDate = getAs[java.time.LocalDate](i)
+
   /**
    * Returns the value at position i of date type as java.sql.Timestamp.
    *
@@ -276,6 +283,13 @@ trait Row extends Serializable {
    */
   def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i)
 
+  /**
+   * Returns the value at position i of timestamp type as java.time.Instant.
+   *
+   * @throws ClassCastException when data type does not match.
+   */
+  def getInstant(i: Int): java.time.Instant = getAs[java.time.Instant](i)
+
   /**
    * Returns the value at position i of array type as a Scala Seq.
    *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 39132139237c..c5be3efc6371 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -102,7 +102,9 @@ object JavaTypeInference {
 
       case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
       case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true)
+      case c: Class[_] if c == classOf[java.time.LocalDate] => (DateType, true)
       case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
+      case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
 
       case _ if typeToken.isArray =>
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index c5f38676ad0a..7bf0789b43d6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -18,10 +18,15 @@
 package test.org.apache.spark.sql;
 
 import java.io.Serializable;
+import java.time.Instant;
+import java.time.LocalDate;
 import java.util.*;
 
 import org.apache.spark.sql.*;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.catalyst.util.TimestampFormatter;
+import org.apache.spark.sql.internal.SQLConf;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructType;
 import org.junit.*;
@@ -509,4 +514,108 @@ public void setId(Integer id) {
       this.id = id;
     }
   }
+
+  /**
+   * With the Java 8 datetime API flag enabled, checks that rows holding LocalDate/Instant
+   * values under a DateType/TimestampType schema deserialize into String bean properties.
+   */
+  @Test
+  public void testBeanWithLocalDateAndInstant() {
+    String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key());
+    try {
+      spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), "true");
+      List<Row> inputRows = new ArrayList<>();
+      List<LocalDateInstantRecord> expectedRecords = new ArrayList<>();
+
+      for (long idx = 0 ; idx < 5 ; idx++) {
+        Row row = createLocalDateInstantRow(idx);
+        inputRows.add(row);
+        expectedRecords.add(createLocalDateInstantRecord(row));
+      }
+
+      Encoder<LocalDateInstantRecord> encoder = Encoders.bean(LocalDateInstantRecord.class);
+
+      StructType schema = new StructType()
+        .add("localDateField", DataTypes.DateType)
+        .add("instantField", DataTypes.TimestampType);
+
+      Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
+      Dataset<LocalDateInstantRecord> dataset = dataFrame.as(encoder);
+
+      List<LocalDateInstantRecord> records = dataset.collectAsList();
+
+      Assert.assertEquals(expectedRecords, records);
+    } finally {
+      // Put back the value read before the test so the session config is left unchanged.
+      spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), originConf);
+    }
+  }
+
+  /** Bean whose date and timestamp properties are both declared as String. */
+  public static final class LocalDateInstantRecord {
+    private String localDateField;
+    private String instantField;
+
+    public LocalDateInstantRecord() { }
+
+    public String getLocalDateField() {
+      return localDateField;
+    }
+
+    public void setLocalDateField(String localDateField) {
+      this.localDateField = localDateField;
+    }
+
+    public String getInstantField() {
+      return instantField;
+    }
+
+    public void setInstantField(String instantField) {
+      this.instantField = instantField;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      LocalDateInstantRecord that = (LocalDateInstantRecord) o;
+      return Objects.equals(localDateField, that.localDateField) &&
+        Objects.equals(instantField, that.instantField);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(localDateField, instantField);
+    }
+
+    @Override
+    public String toString() {
+      return com.google.common.base.Objects.toStringHelper(this)
+          .add("localDateField", localDateField)
+          .add("instantField", instantField)
+          .toString();
+    }
+  }
+
+  /** Builds an input row whose date and instant are both offset from 42 by {@code index}. */
+  private static Row createLocalDateInstantRow(Long index) {
+    Object[] values = new Object[] {
+      LocalDate.ofEpochDay(42 + index), Instant.ofEpochSecond(42 + index)
+    };
+    return new GenericRow(values);
+  }
+
+  /**
+   * Builds the expected bean for a row: the date via LocalDate's string form, the timestamp
+   * via the fraction formatter created for the session-local time zone.
+   */
+  private static LocalDateInstantRecord createLocalDateInstantRecord(Row recordRow) {
+    LocalDateInstantRecord record = new LocalDateInstantRecord();
+    record.setLocalDateField(String.valueOf(recordRow.getLocalDate(0)));
+    Instant instant = recordRow.getInstant(1);
+    TimestampFormatter formatter = TimestampFormatter.getFractionFormatter(
+      DateTimeUtils.getZoneId(SQLConf.get().sessionLocalTimeZone()));
+    record.setInstantField(formatter.format(DateTimeUtils.instantToMicros(instant)));
+    return record;
+  }
 }