diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b0c51b1e9992..874eed436c9e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -661,7 +661,7 @@ def text(self, path, compression=None): @since(2.0) def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, - timestampFormat=None): + timestampFormat=None, encoding=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -701,13 +701,15 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + :param encoding: encodes the CSV files by the given encoding type. If None is set, + it uses the default value, ``UTF-8``. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header, nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll, - dateFormat=dateFormat, timestampFormat=timestampFormat) + dateFormat=dateFormat, timestampFormat=timestampFormat, encoding=encoding) self._jwrite.csv(path) @since(1.5) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 9c5660a3780a..35d75dd72d5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -572,6 +572,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + *
  • `encoding` (default `UTF-8`): encodes the CSV files by the given encoding + * type.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala index 6239508ec942..025de5b29276 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.datasources.csv -import java.io.{CharArrayWriter, OutputStream, StringReader} -import java.nio.charset.StandardCharsets +import java.io.OutputStream +import java.nio.charset.Charset import com.univocity.parsers.csv._ @@ -71,6 +71,7 @@ private[csv] class LineCsvWriter( output: OutputStream) extends Logging { private val writerSettings = new CsvWriterSettings private val format = writerSettings.getFormat + private val writerCharset = Charset.forName(params.charset) format.setDelimiter(params.delimiter) format.setQuote(params.quote) @@ -84,7 +85,7 @@ private[csv] class LineCsvWriter( writerSettings.setHeaders(headers: _*) writerSettings.setQuoteEscapingEnabled(params.escapeQuotes) - private val writer = new CsvWriter(output, StandardCharsets.UTF_8, writerSettings) + private val writer = new CsvWriter(output, writerCharset, writerSettings) def writeRow(row: Seq[String], includeHeader: Boolean): Unit = { if (includeHeader) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 491ff72337a8..fe962cc90e79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -905,4 +905,23 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(df, Row(1, null)) } } + + test("save data with gb18030") { + withTempPath { path => + // scalastyle:off + val df = Seq(("1", "中文")).toDF("num", "lanaguage") + // scalastyle:on + df.write + .option("header", "true") + .option("encoding", "GB18030") + .csv(path.getAbsolutePath) + + val readBack = spark.read + .option("header", "true") + .option("encoding", "GB18030") + .csv(path.getAbsolutePath) + + checkAnswer(df, readBack) + } + } }