diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 3f28d7ad5051..f86881e70a81 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -283,8 +283,8 @@ def text(self, paths):
     def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
             comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
             ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
-            negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None,
-            maxMalformedLogPerPartition=None, mode=None):
+            negativeInf=None, dateFormat=None, timezone=None, maxColumns=None,
+            maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None):
         """Loads a CSV file and returns the result as a :class:`DataFrame`.

         This function will go through the input once to determine the input schema if
@@ -328,6 +328,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
                            applies to both date type and timestamp type. By default, it is None
                            which means trying to parse times and date by
                            ``java.sql.Timestamp.valueOf()`` and ``java.sql.Date.valueOf()``.
+        :param timezone: defines the timezone to be used for both date type and timestamp type.
+                         If a timezone is specified in the data, the values are loaded after
+                         adjusting for the time difference between the two timezones. If None is
+                         set, it uses the timezone of your current system.
         :param maxColumns: defines a hard limit of how many columns a record can have. If None is
                            set, it uses the default value, ``20480``.
         :param maxCharsPerColumn: defines the maximum number of characters allowed for any given
@@ -354,7 +358,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
             nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
-            dateFormat=dateFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn,
+            dateFormat=dateFormat, timezone=timezone, maxColumns=maxColumns,
+            maxCharsPerColumn=maxCharsPerColumn,
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode)
         if isinstance(path, basestring):
             path = [path]
@@ -631,7 +636,7 @@ def text(self, path, compression=None):

     @since(2.0)
     def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
-            header=None, nullValue=None, escapeQuotes=None):
+            header=None, dateFormat=None, timezone=None, nullValue=None, escapeQuotes=None):
         """Saves the content of the :class:`DataFrame` in CSV format at the specified path.

         :param path: the path in any Hadoop supported file system
@@ -658,6 +663,13 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
                              ``true``, escaping all values containing a quote character.
         :param header: writes the names of columns as the first line. If None is set, it uses
                        the default value, ``false``.
+        :param dateFormat: sets the string that indicates a date format. Custom date formats
+                           follow the formats at ``java.text.SimpleDateFormat``. This applies to
+                           both date type and timestamp type. By default, it is None which means
+                           writing both as numeric timestamps.
+        :param timezone: defines the timezone to be used with the ``dateFormat`` option. Values
+                         are formatted in this timezone, and if ``dateFormat`` contains a timezone
+                         pattern (e.g. ``Z``), the matching offset is written out as well.
         :param nullValue: sets the string representation of a null value. If None is set, it uses
                           the default value, empty string.
@@ -665,7 +677,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
         """
         self.mode(mode)
         self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header,
-                       nullValue=nullValue, escapeQuotes=escapeQuotes)
+                       nullValue=nullValue, escapeQuotes=escapeQuotes, dateFormat=dateFormat,
+                       timezone=timezone)
         self._jwrite.csv(path)

     @since(1.5)
@@ -949,8 +962,8 @@ def text(self, path):
     def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
             comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
             ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
-            negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None,
-            maxMalformedLogPerPartition=None, mode=None):
+            negativeInf=None, dateFormat=None, timezone=None, maxColumns=None,
+            maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None):
         """Loads a CSV file stream and returns the result as a :class:`DataFrame`.

         This function will go through the input once to determine the input schema if
@@ -996,6 +1009,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
                            applies to both date type and timestamp type. By default, it is None
                            which means trying to parse times and date by
                            ``java.sql.Timestamp.valueOf()`` and ``java.sql.Date.valueOf()``.
+        :param timezone: defines the timezone to be used for both date type and timestamp type.
+                         If a timezone is specified in the data, the values are loaded after
+                         adjusting for the time difference between the two timezones. If None is
+                         set, it uses the timezone of your current system.
         :param maxColumns: defines a hard limit of how many columns a record can have. If None is
                            set, it uses the default value, ``20480``.
         :param maxCharsPerColumn: defines the maximum number of characters allowed for any given
@@ -1021,7 +1038,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
             nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
-            dateFormat=dateFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn,
+            dateFormat=dateFormat, timezone=timezone, maxColumns=maxColumns,
+            maxCharsPerColumn=maxCharsPerColumn,
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode)
         if isinstance(path, basestring):
             return self._df(self._jreader.csv(path))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 35ba52278633..60fa56414e56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -378,6 +378,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * formats follow the formats at `java.text.SimpleDateFormat`. This applies to both date type
    * and timestamp type. By default, it is `null` which means trying to parse times and date by
    * `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()`.</li>
+   * <li>`timezone` (default is the timezone of your current system): defines the timezone to be
+   * used for both date type and timestamp type. If a timezone is specified in the data, the
+   * values are loaded after adjusting for the time difference between the two timezones.</li>
    * <li>`maxColumns` (default `20480`): defines a hard limit of how many columns
    * a record can have.</li>
    * <li>`maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index ca3972d62dfb..b73b49bb245e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -537,6 +537,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    * quotes should always be enclosed in quotes. Default is to escape all values containing
    * a quote character.</li>
    * <li>`header` (default `false`): writes the names of columns as the first line.</li>
+   * <li>`dateFormat` (default `null`): sets the string that indicates a date format. Custom date
+   * formats follow the formats at `java.text.SimpleDateFormat`. This applies to both date type
+   * and timestamp type. By default, it is `null` which means writing both as numeric
+   * timestamps.</li>
+   * <li>`timezone` (default is the timezone of your current system): defines the timezone used
+   * with the `dateFormat` option: values are formatted in this timezone, and if `dateFormat`
+   * contains a timezone pattern (e.g. `Z`), the matching offset is written out as well.</li>
    * <li>`nullValue` (default empty string): sets the string representation of a null value.</li>
    * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
    *   one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 581eda7e09a3..b6ec52e018fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.csv

 import java.nio.charset.StandardCharsets
 import java.text.SimpleDateFormat
+import java.util.TimeZone

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes}
@@ -101,10 +102,16 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
     name.map(CompressionCodecs.getCodecClassName)
   }

+  private val timezone: Option[String] = parameters.get("timezone")
+
   // Share date format object as it is expensive to parse date pattern.
   val dateFormat: SimpleDateFormat = {
     val dateFormat = parameters.get("dateFormat")
-    dateFormat.map(new SimpleDateFormat(_)).orNull
+    dateFormat.map { f =>
+      val format = new SimpleDateFormat(f)
+      timezone.foreach(tz => format.setTimeZone(TimeZone.getTimeZone(tz)))
+      format
+    }.orNull
   }

   val maxColumns = getInt("maxColumns", 20480)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index e8c0134d3880..4d7388f2cb79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
 import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.types._
@@ -179,6 +180,12 @@ private[sql] class CsvOutputWriter(
   // create the Generator without separator inserted between 2 records
   private[this] val text = new Text()

+  // A `ValueConverter` is responsible for converting a field of an `InternalRow` to `String`.
+  private type ValueConverter = (InternalRow, Int) => String
+
+  // `ValueConverter`s for all fields of the schema
+  private val fieldsConverters: Seq[ValueConverter] = dataSchema.map(_.dataType).map(makeConverter)
+
   private val recordWriter: RecordWriter[NullWritable, Text] = {
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
@@ -195,18 +202,40 @@ private[sql] class CsvOutputWriter(
   private var records: Long = 0L
   private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)

-  private def rowToString(row: Seq[Any]): Seq[String] = row.map { field =>
-    if (field != null) {
-      field.toString
-    } else {
-      params.nullValue
+  private def rowToString(row: InternalRow): Seq[String] = {
+    var i = 0
+    val values = new Array[String](row.numFields)
+    while (i < row.numFields) {
+      if (!row.isNullAt(i)) {
+        values(i) = fieldsConverters(i).apply(row, i)
+      } else {
+        values(i) = params.nullValue
+      }
+      i += 1
     }
+    values
+  }
+
+  private def makeConverter(dataType: DataType): ValueConverter = dataType match {
+    case DateType if params.dateFormat != null =>
+      (row: InternalRow, ordinal: Int) =>
+        params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
+
+    case TimestampType if params.dateFormat != null =>
+      (row: InternalRow, ordinal: Int) =>
+        params.dateFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
+
+    case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
+
+    case dt: DataType =>
+      (row: InternalRow, ordinal: Int) =>
+        row.get(ordinal, dt).toString
   }

   override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")

   override protected[sql] def writeInternal(row: InternalRow): Unit = {
-    csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), records == 0L && params.headerFlag)
+    csvWriter.writeRow(rowToString(row), records == 0L && params.headerFlag)
     records += 1
     if (records % FLUSH_BATCH_SIZE == 0) {
       flush()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index f170065132ac..646f0948cb6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -665,4 +665,121 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     verifyCars(cars, withHeader = true, checkValues = false)
   }
+
+  test("Write timestamps correctly with dateFormat and timezone option") {
+    withTempDir { dir =>
+      // With dateFormat option.
+      val datesWithFormatPath = s"${dir.getCanonicalPath}/datesWithFormat.csv"
+      val datesWithFormat = spark.read
+        .format("csv")
+        .option("header", "true")
+        .option("inferSchema", "true")
+        .option("dateFormat", "dd/MM/yyyy HH:mm")
+        .load(testFile(datesFile))
+      datesWithFormat.write
+        .format("csv")
+        .option("header", "true")
+        .option("dateFormat", "yyyy/MM/dd HH:mm")
+        .save(datesWithFormatPath)
+
+      // This will load back the timestamps as string.
+      val stringDatesWithFormat = spark.read
+        .format("csv")
+        .option("header", "true")
+        .option("inferSchema", "false")
+        .load(datesWithFormatPath)
+      val expectedStringDatesWithFormat = Seq(
+        Row("2015/08/26 18:00"),
+        Row("2014/10/27 18:30"),
+        Row("2016/01/28 20:00"))
+
+      checkAnswer(stringDatesWithFormat, expectedStringDatesWithFormat)
+
+      // With dateFormat and timezone option.
+      val datesWithZoneAndFormatPath = s"${dir.getCanonicalPath}/datesWithZoneAndFormat.csv"
+      val datesWithZoneAndFormat = spark.read
+        .format("csv")
+        .option("header", "true")
+        .option("inferSchema", "true")
+        .option("timezone", "GMT")
+        .option("dateFormat", "dd/MM/yyyy HH:mm")
+        .load(testFile(datesFile))
+      datesWithZoneAndFormat.write
+        .format("csv")
+        .option("header", "true")
+        .option("timezone", "Asia/Seoul")
+        .option("dateFormat", "dd/MM/yyyy HH:mmZ")
+        .save(datesWithZoneAndFormatPath)
+
+      // This will load back the timestamps as string.
+      val stringDates = spark.read
+        .format("csv")
+        .option("header", "true")
+        .load(datesWithZoneAndFormatPath)
+      val expectedStringDates = Seq(
+        Row("27/08/2015 03:00+0900"),
+        Row("28/10/2014 03:30+0900"),
+        Row("29/01/2016 05:00+0900"))
+
+      checkAnswer(stringDates, expectedStringDates)
+    }
+  }
+
+  test("Write dates correctly with dateFormat and timezone option") {
+    val customSchema = new StructType(Array(StructField("date", DateType, true)))
+    withTempDir { dir =>
+      // With dateFormat option.
+      val datesWithFormatPath = s"${dir.getCanonicalPath}/datesWithFormat.csv"
+      val datesWithFormat = spark.read
+        .format("csv")
+        .schema(customSchema)
+        .option("header", "true")
+        .option("dateFormat", "dd/MM/yyyy HH:mm")
+        .load(testFile(datesFile))
+      datesWithFormat.write
+        .format("csv")
+        .option("header", "true")
+        .option("dateFormat", "yyyy/MM/dd")
+        .save(datesWithFormatPath)
+
+      // This will load back the dates as string.
+      val stringDatesWithFormat = spark.read
+        .format("csv")
+        .option("header", "true")
+        .load(datesWithFormatPath)
+      val expectedStringDatesWithFormat = Seq(
+        Row("2015/08/26"),
+        Row("2014/10/27"),
+        Row("2016/01/28"))
+
+      checkAnswer(stringDatesWithFormat, expectedStringDatesWithFormat)
+
+      // With dateFormat and timezone option.
+      val datesWithZoneAndFormatPath = s"${dir.getCanonicalPath}/datesWithZoneAndFormat.csv"
+      val datesWithZoneAndFormat = spark.read
+        .format("csv")
+        .schema(customSchema)
+        .option("header", "true")
+        .option("dateFormat", "dd/MM/yyyy HH:mm")
+        .load(testFile(datesFile))
+      datesWithZoneAndFormat.write
+        .format("csv")
+        .option("header", "true")
+        .option("timezone", "GMT")
+        .option("dateFormat", "dd/MM/yyyy z")
+        .save(datesWithZoneAndFormatPath)
+
+      // This will load back the dates as string.
+      val stringDates = spark.read
+        .format("csv")
+        .option("header", "true")
+        .load(datesWithZoneAndFormatPath)
+      val expectedStringDates = Seq(
+        Row("26/08/2015 GMT"),
+        Row("27/10/2014 GMT"),
+        Row("28/01/2016 GMT"))
+
+      checkAnswer(stringDates, expectedStringDates)
+    }
+  }
+}
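
A minimal usage sketch of the new option, distilled from the tests above. It assumes an existing SparkSession `spark` and a DataFrame `df` with a timestamp column; the output path and the timezone IDs are arbitrary placeholders.

// Write timestamps formatted in Seoul wall-clock time; the trailing `Z`
// pattern letter also emits the +0900 offset.
df.write
  .format("csv")
  .option("header", "true")
  .option("dateFormat", "dd/MM/yyyy HH:mmZ")
  .option("timezone", "Asia/Seoul")
  .save("/tmp/timestamps.csv")

// Read the values back with the same pattern, interpreting them relative to
// GMT instead of the JVM default timezone.
val readBack = spark.read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .option("dateFormat", "dd/MM/yyyy HH:mmZ")
  .option("timezone", "GMT")
  .load("/tmp/timestamps.csv")

Note that `java.util.TimeZone.getTimeZone`, which CSVOptions uses to resolve the option value, silently falls back to GMT for IDs it does not recognize, so a mistyped `timezone` value fails quietly rather than raising an error.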