diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 3f28d7ad5051..f86881e70a81 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -283,8 +283,8 @@ def text(self, paths):
def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
- negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None,
- maxMalformedLogPerPartition=None, mode=None):
+ negativeInf=None, dateFormat=None, timezone=None, maxColumns=None,
+ maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -328,6 +328,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
applies to both date type and timestamp type. By default, it is None
which means trying to parse times and date by
``java.sql.Timestamp.valueOf()`` and ``java.sql.Date.valueOf()``.
+        :param timezone: defines the timezone to be used for both date type and timestamp type.
+                         If a timezone is also specified in the data, values are loaded after
+                         adjusting for the difference between the two timezones. If None is
+                         set, it uses the timezone of your current system.
:param maxColumns: defines a hard limit of how many columns a record can have. If None is
set, it uses the default value, ``20480``.
:param maxCharsPerColumn: defines the maximum number of characters allowed for any given
@@ -354,7 +358,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
- dateFormat=dateFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn,
+ dateFormat=dateFormat, timezone=timezone, maxColumns=maxColumns,
+ maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode)
if isinstance(path, basestring):
path = [path]
@@ -631,7 +636,7 @@ def text(self, path, compression=None):
@since(2.0)
def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
- header=None, nullValue=None, escapeQuotes=None):
+ header=None, dateFormat=None, timezone=None, nullValue=None, escapeQuotes=None):
"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -658,6 +663,13 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
``true``, escaping all values containing a quote character.
:param header: writes the names of columns as the first line. If None is set, it uses
the default value, ``false``.
+ :param dateFormat: sets the string that indicates a date format. Custom date formats
+ follow the formats at ``java.text.SimpleDateFormat``. This applies to
+ both date type and timestamp type. By default, it is None which means
+ writing both as numeric timestamps.
+        :param timezone: defines the timezone to be used together with the ``dateFormat`` option.
+                         Date and time fields are formatted in this timezone and, if the format
+                         contains a timezone pattern (e.g. ``Z``), its offset is written out too.
:param nullValue: sets the string representation of a null value. If None is set, it uses
the default value, empty string.
@@ -665,7 +677,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
"""
self.mode(mode)
self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header,
- nullValue=nullValue, escapeQuotes=escapeQuotes)
+ nullValue=nullValue, escapeQuotes=escapeQuotes, dateFormat=dateFormat,
+ timezone=timezone)
self._jwrite.csv(path)
@since(1.5)
@@ -949,8 +962,8 @@ def text(self, path):
def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
- negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None,
- maxMalformedLogPerPartition=None, mode=None):
+ negativeInf=None, dateFormat=None, timezone=None, maxColumns=None,
+ maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -996,6 +1009,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
applies to both date type and timestamp type. By default, it is None
which means trying to parse times and date by
``java.sql.Timestamp.valueOf()`` and ``java.sql.Date.valueOf()``.
+        :param timezone: defines the timezone to be used for both date type and timestamp type.
+                         If a timezone is also specified in the data, values are loaded after
+                         adjusting for the difference between the two timezones. If None is
+                         set, it uses the timezone of your current system.
:param maxColumns: defines a hard limit of how many columns a record can have. If None is
set, it uses the default value, ``20480``.
:param maxCharsPerColumn: defines the maximum number of characters allowed for any given
@@ -1021,7 +1038,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
- dateFormat=dateFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn,
+ dateFormat=dateFormat, timezone=timezone, maxColumns=maxColumns,
+ maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 35ba52278633..60fa56414e56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -378,6 +378,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* formats follow the formats at `java.text.SimpleDateFormat`. This applies to both date type
* and timestamp type. By default, it is `null` which means trying to parse times and date by
* `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()`.
+   * `timezone` (default is the timezone of your current system): defines the timezone to be
+   * used for both date type and timestamp type. If a timezone is also specified in the data,
+   * values are loaded after adjusting for the difference between the two timezones.
* `maxColumns` (default `20480`): defines a hard limit of how many columns
* a record can have.
* `maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index ca3972d62dfb..b73b49bb245e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -537,6 +537,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* quotes should always be enclosed in quotes. Default is to escape all values containing
* a quote character.
* `header` (default `false`): writes the names of columns as the first line.
+ * `dateFormat` (default `null`): sets the string that indicates a date format. Custom date
+ * formats follow the formats at `java.text.SimpleDateFormat`. This applies to both date type
+ * and timestamp type. By default, it is `null` which means writing both as numeric
+ * timestamps.
+   * `timezone` (default is the timezone of your current system): defines the timezone used
+   * together with the `dateFormat` option. Date and time fields are formatted in this timezone
+   * and, if the format contains a timezone pattern (e.g. `Z`), its offset is written out too.
* `nullValue` (default empty string): sets the string representation of a null value.
* `compression` (default `null`): compression codec to use when saving to file. This can be
* one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
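For reference, a minimal usage sketch of the two options documented above (`dateFormat` and `timezone`), mirroring the test cases added at the end of this patch; the paths are placeholders, and the Python API exposes the same options as keyword arguments:

```scala
// Sketch only: exercises the reader/writer options introduced by this patch.
// "/tmp/in.csv" and "/tmp/out" are hypothetical paths.
val df = spark.read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .option("dateFormat", "dd/MM/yyyy HH:mm")   // pattern used to parse timestamps in the input
  .option("timezone", "GMT")                  // zone applied to the shared date formatter
  .load("/tmp/in.csv")

df.write
  .format("csv")
  .option("header", "true")
  .option("dateFormat", "dd/MM/yyyy HH:mmZ")  // pattern used to render timestamps on write
  .option("timezone", "Asia/Seoul")           // values come out shifted, e.g. 27/08/2015 03:00+0900
  .save("/tmp/out")
```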
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 581eda7e09a3..b6ec52e018fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.StandardCharsets
import java.text.SimpleDateFormat
+import java.util.TimeZone
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes}
@@ -101,10 +102,16 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
name.map(CompressionCodecs.getCodecClassName)
}
+ private val timezone: Option[String] = parameters.get("timezone")
+
// Share date format object as it is expensive to parse date pattern.
val dateFormat: SimpleDateFormat = {
val dateFormat = parameters.get("dateFormat")
- dateFormat.map(new SimpleDateFormat(_)).orNull
+ dateFormat.map { f =>
+ val format = new SimpleDateFormat(f)
+ timezone.foreach(tz => format.setTimeZone(TimeZone.getTimeZone(tz)))
+ format
+ }.orNull
}
val maxColumns = getInt("maxColumns", 20480)
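The `timezone` handling above simply installs a `TimeZone` on the shared `SimpleDateFormat`. As a standalone illustration of what that changes (not part of the patch), the same instant renders differently depending on the zone set on the formatter:

```scala
import java.text.SimpleDateFormat
import java.util.{Date, TimeZone}

object TimeZoneFormatSketch {
  def main(args: Array[String]): Unit = {
    val instant = new Date(0L)  // 1970-01-01T00:00:00 UTC

    val gmt = new SimpleDateFormat("yyyy/MM/dd HH:mmZ")
    gmt.setTimeZone(TimeZone.getTimeZone("GMT"))

    val seoul = new SimpleDateFormat("yyyy/MM/dd HH:mmZ")
    seoul.setTimeZone(TimeZone.getTimeZone("Asia/Seoul"))

    println(gmt.format(instant))    // 1970/01/01 00:00+0000
    println(seoul.format(instant))  // 1970/01/01 09:00+0900
  }
}
```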
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index e8c0134d3880..4d7388f2cb79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.types._
@@ -179,6 +180,12 @@ private[sql] class CsvOutputWriter(
// create the Generator without separator inserted between 2 records
private[this] val text = new Text()
+ // A `ValueConverter` is responsible for converting a field of an `InternalRow` to `String`.
+ private type ValueConverter = (InternalRow, Int) => String
+
+ // `ValueConverter`s for all fields of the schema
+ private val fieldsConverters: Seq[ValueConverter] = dataSchema.map(_.dataType).map(makeConverter)
+
private val recordWriter: RecordWriter[NullWritable, Text] = {
new TextOutputFormat[NullWritable, Text]() {
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
@@ -195,18 +202,40 @@ private[sql] class CsvOutputWriter(
private var records: Long = 0L
private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
- private def rowToString(row: Seq[Any]): Seq[String] = row.map { field =>
- if (field != null) {
- field.toString
- } else {
- params.nullValue
+ private def rowToString(row: InternalRow): Seq[String] = {
+ var i = 0
+ val values = new Array[String](row.numFields)
+ while (i < row.numFields) {
+ if (!row.isNullAt(i)) {
+ values(i) = fieldsConverters(i).apply(row, i)
+ } else {
+ values(i) = params.nullValue
+ }
+ i += 1
}
+ values
+ }
+
+ private def makeConverter(dataType: DataType): ValueConverter = dataType match {
+ case DateType if params.dateFormat != null =>
+ (row: InternalRow, ordinal: Int) =>
+ params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
+
+ case TimestampType if params.dateFormat != null =>
+ (row: InternalRow, ordinal: Int) =>
+ params.dateFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
+
+ case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
+
+ case dt: DataType =>
+ (row: InternalRow, ordinal: Int) =>
+ row.get(ordinal, dt).toString
}
override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
override protected[sql] def writeInternal(row: InternalRow): Unit = {
- csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), records == 0L && params.headerFlag)
+ csvWriter.writeRow(rowToString(row), records == 0L && params.headerFlag)
records += 1
if (records % FLUSH_BATCH_SIZE == 0) {
flush()
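Note that the converters above read Spark's internal column representations: `DateType` values are stored as days since the epoch (`Int`) and `TimestampType` values as microseconds since the epoch (`Long`), hence the `getInt`/`getLong` calls before `DateTimeUtils` turns them into `java.sql` values. A rough standalone sketch of the same conversion, without the `InternalRow` machinery:

```scala
import java.text.SimpleDateFormat
import java.util.{Date, TimeZone}

object InternalValueFormatSketch {
  private val MillisPerDay = 24L * 60 * 60 * 1000

  // DateType: days since 1970-01-01 -> formatted string
  def formatDays(days: Int, fmt: SimpleDateFormat): String =
    fmt.format(new Date(days * MillisPerDay))

  // TimestampType: microseconds since epoch -> formatted string
  def formatMicros(micros: Long, fmt: SimpleDateFormat): String =
    fmt.format(new Date(micros / 1000L))

  def main(args: Array[String]): Unit = {
    val fmt = new SimpleDateFormat("yyyy/MM/dd HH:mm")
    fmt.setTimeZone(TimeZone.getTimeZone("GMT"))
    println(formatDays(16673, fmt))               // 2015/08/26 00:00
    println(formatMicros(1440612000000000L, fmt)) // 2015/08/26 18:00
  }
}
```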
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index f170065132ac..646f0948cb6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -665,4 +665,121 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(cars, withHeader = true, checkValues = false)
}
+
+ test("Write timestamps correctly with dateFormat and timezone option") {
+ withTempDir { dir =>
+ // With dateFormat option.
+ val datesWithFormatPath = s"${dir.getCanonicalPath}/datesWithFormat.csv"
+ val datesWithFormat = spark.read
+ .format("csv")
+ .option("header", "true")
+ .option("inferSchema", "true")
+ .option("dateFormat", "dd/MM/yyyy HH:mm")
+ .load(testFile(datesFile))
+ datesWithFormat.write
+ .format("csv")
+ .option("header", "true")
+ .option("dateFormat", "yyyy/MM/dd HH:mm")
+ .save(datesWithFormatPath)
+
+      // This will load back the timestamps as strings.
+ val stringDatesWithFormat = spark.read
+ .format("csv")
+ .option("header", "true")
+ .option("inferSchema", "false")
+ .load(datesWithFormatPath)
+ val expectedStringDatesWithFormat = Seq(
+ Row("2015/08/26 18:00"),
+ Row("2014/10/27 18:30"),
+ Row("2016/01/28 20:00"))
+
+ checkAnswer(stringDatesWithFormat, expectedStringDatesWithFormat)
+
+ // With dateFormat and timezone option.
+ val datesWithZoneAndFormatPath = s"${dir.getCanonicalPath}/datesWithZoneAndFormat.csv"
+ val datesWithZoneAndFormat = spark.read
+ .format("csv")
+ .option("header", "true")
+ .option("inferSchema", "true")
+ .option("timezone", "GMT")
+ .option("dateFormat", "dd/MM/yyyy HH:mm")
+ .load(testFile(datesFile))
+ datesWithZoneAndFormat.write
+ .format("csv")
+ .option("header", "true")
+ .option("timezone", "Asia/Seoul")
+ .option("dateFormat", "dd/MM/yyyy HH:mmZ")
+ .save(datesWithZoneAndFormatPath)
+
+      // This will load back the timestamps as strings.
+ val stringDates = spark.read
+ .format("csv")
+ .option("header", "true")
+ .load(datesWithZoneAndFormatPath)
+ val expectedStringDates = Seq(
+ Row("27/08/2015 03:00+0900"),
+ Row("28/10/2014 03:30+0900"),
+ Row("29/01/2016 05:00+0900"))
+
+ checkAnswer(stringDates, expectedStringDates)
+ }
+ }
+
+ test("Write dates correctly with dateFormat and timezone option") {
+ val customSchema = new StructType(Array(StructField("date", DateType, true)))
+ withTempDir { dir =>
+ // With dateFormat option.
+ val datesWithFormatPath = s"${dir.getCanonicalPath}/datesWithFormat.csv"
+ val datesWithFormat = spark.read
+ .format("csv")
+ .schema(customSchema)
+ .option("header", "true")
+ .option("dateFormat", "dd/MM/yyyy HH:mm")
+ .load(testFile(datesFile))
+ datesWithFormat.write
+ .format("csv")
+ .option("header", "true")
+ .option("dateFormat", "yyyy/MM/dd")
+ .save(datesWithFormatPath)
+
+      // This will load back the dates as strings.
+ val stringDatesWithFormat = spark.read
+ .format("csv")
+ .option("header", "true")
+ .load(datesWithFormatPath)
+ val expectedStringDatesWithFormat = Seq(
+ Row("2015/08/26"),
+ Row("2014/10/27"),
+ Row("2016/01/28"))
+
+ checkAnswer(stringDatesWithFormat, expectedStringDatesWithFormat)
+
+ // With dateFormat and timezone option.
+ val datesWithZoneAndFormatPath = s"${dir.getCanonicalPath}/datesWithZoneAndFormat.csv"
+ val datesWithZoneAndFormat = spark.read
+ .format("csv")
+ .schema(customSchema)
+ .option("header", "true")
+ .option("dateFormat", "dd/MM/yyyy HH:mm")
+ .load(testFile(datesFile))
+ datesWithZoneAndFormat.write
+ .format("csv")
+ .option("header", "true")
+        .option("timezone", "GMT")
+ .option("dateFormat", "dd/MM/yyyy z")
+ .save(datesWithZoneAndFormatPath)
+
+      // This will load back the dates as strings.
+ val stringDates = spark.read
+ .format("csv")
+ .option("header", "true")
+ .load(datesWithZoneAndFormatPath)
+ val expectedStringDates = Seq(
+ Row("26/08/2015 GMT"),
+ Row("27/10/2014 GMT"),
+ Row("28/01/2016 GMT"))
+
+ checkAnswer(stringDates, expectedStringDates)
+ }
+ }
}